diff --git a/build_variables.bzl b/build_variables.bzl index bb5cbca0d6914..5cf9017d9af6e 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -32,6 +32,7 @@ libtorch_nvfuser_runtime_sources = [ "torch/csrc/jit/codegen/cuda/runtime/helpers.cu", "torch/csrc/jit/codegen/cuda/runtime/index_utils.cu", "torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu", + "torch/csrc/jit/codegen/cuda/runtime/swizzle.cu", "torch/csrc/jit/codegen/cuda/runtime/memory.cu", "torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu", "torch/csrc/jit/codegen/cuda/runtime/tensor.cu", @@ -643,6 +644,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/autograd/functions/comm.cpp", "torch/csrc/jit/codegen/cuda/arith.cpp", "torch/csrc/jit/codegen/cuda/compute_at.cpp", + "torch/csrc/jit/codegen/cuda/inline_propagator.cpp", "torch/csrc/jit/codegen/cuda/compute_at_map.cpp", "torch/csrc/jit/codegen/cuda/codegen.cpp", "torch/csrc/jit/codegen/cuda/contiguity.cpp", @@ -658,7 +660,6 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/grouped_reduction.cpp", "torch/csrc/jit/codegen/cuda/index_compute.cpp", "torch/csrc/jit/codegen/cuda/lower_index_compute.cpp", - "torch/csrc/jit/codegen/cuda/index_reference_replay.cpp", "torch/csrc/jit/codegen/cuda/instrumentation.cpp", "torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp", "torch/csrc/jit/codegen/cuda/ir_builder.cpp", diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index 2c2c312e6ebd7..43dd5151c9abc 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -83,6 +83,11 @@ def is_pre_volta(): TEST_BF16 = RUN_NVFUSER and torch.cuda.is_bf16_supported() +TEST_LARGE_TENSOR = RUN_NVFUSER +if RUN_NVFUSER: + torch.ones(1).cuda() # initialize cuda context + TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9 + class CudaFuserTestOptions(): def __init__(self): self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu() @@ -184,23 +189,27 @@ def tearDown(self): self.cuda_fuser_options.restore() super(TestCudaFuser, self).tearDown() - def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1): - torch.cuda.manual_seed_all(123) - jit_o = jit_op(*args) - torch.cuda.manual_seed_all(123) + def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1, check_runs=1): + seed = 123 + torch.cuda.manual_seed_all(seed) jit_o = jit_op(*args) - torch.cuda.manual_seed_all(123) - o = op(*args) - if type(jit_o) is torch.Tensor: - jit_o = [jit_o, ] - o = [o, ] + for i in range(check_runs): + torch.cuda.manual_seed_all(seed + i) + jit_o = jit_op(*args) + torch.cuda.manual_seed_all(seed + i) + o = op(*args) + + if type(jit_o) is torch.Tensor: + jit_o = [jit_o, ] + o = [o, ] + + for oo, jit_oo in zip(o, jit_o): + self.assertEqual(oo.dtype, jit_oo.dtype) + self.assertEqual(oo, jit_oo) + if check_stride: + self.assertEqual(oo.stride(), jit_oo.stride()) - for oo, jit_oo in zip(o, jit_o): - self.assertEqual(oo.dtype, jit_oo.dtype) - self.assertEqual(oo, jit_oo) - if check_stride: - self.assertEqual(oo.stride(), jit_oo.stride()) self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, num_fusion, consider_subgraphs=True) def _run_training_helper(self, jit_op, op, grads, *args): @@ -2563,13 +2572,14 @@ def t(x: torch.Tensor, p: float, train: bool): self._run_helper(t_jit, t, x, 0.15, False) + @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be 
effective") def test_dropout_train_nograd_fusion(self): dtype = torch.float device = "cuda" - x = torch.randn([10, 4, 8], dtype=dtype, device=device) + x = torch.randn([64, 128, 1024], dtype=dtype, device=device) def t(x: torch.Tensor, p: float, train: bool): o = torch.nn.functional.dropout(x, p, training=train) @@ -2578,7 +2588,8 @@ def t(x: torch.Tensor, p: float, train: bool): t_jit = torch.jit.script(t) - self._run_helper(t_jit, t, x, 0.0, True) + self._run_helper(t_jit, t, x, 0.0, True, check_runs=20) + self._run_helper(t_jit, t, x, 1.0, True, check_runs=20) @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -4391,6 +4402,33 @@ def t(x): t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x) + + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_to_copy(self): + x = torch.randn(4, 2, device="cuda") + + with nvfuser_singleton_fusion(True): + def t(x, dtype : torch.dtype): + o = torch.ops.aten._to_copy(x, dtype=dtype) + return o + + t.__disable_jit_function_caching__ = True + + t_jit = torch.jit.script(t) + for dtype in [torch.float16, torch.bool, torch.float64]: + self._run_helper(t_jit, t, x, dtype) + + def t_none(x): + with torch.jit.strict_fusion(): + o = torch.ops.aten._to_copy(x, dtype=None) + return o + + t_jit_none = torch.jit.script(t_none) + self._run_helper(t_jit_none, t_none, x) + + @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since reshape is disabled now") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -4752,6 +4790,35 @@ def t(x): jit_t = torch.jit.script(t) self._run_helper(jit_t, t, x) + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_issue_1785(self): + class Fusion(torch.nn.Module): + def __init__(self): + super(Fusion, self).__init__() + + def forward(self, x, a, b): + out = torch.mul(x.unsqueeze(-1), a) + out = out + b + return out + + x = torch.randn(1024, 192, 3, device='cuda') + a = torch.randn(3, 128, device='cuda') + b = torch.randn(3, 128, device='cuda') + + model = Fusion() + jit_model = torch.jit.script(model) + + with torch.jit.fuser('fuser2'): + for _ in range(4): + out_ref = model(x, a, b) + out_jit = jit_model(x, a, b) + + out_ref = model(x, a, b) + out_jit = jit_model(x, a, b) + self.assertTrue(self._compare("comparing output failed", out_ref, out_jit, 1e-5)) + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") diff --git a/torch/csrc/jit/codegen/cuda/README.md b/torch/csrc/jit/codegen/cuda/README.md index 15a9a167d1973..be8aed6c5ce44 100644 --- a/torch/csrc/jit/codegen/cuda/README.md +++ b/torch/csrc/jit/codegen/cuda/README.md @@ -187,7 +187,7 @@ There're a few debug dump that could be turned on via environment variables. Loo 1. `dump_eff_bandwidth`: print out effective bandwidth of each generated kernel. This naively measure the kernel time divided by I/O buffer size and is a good/simple metric of performance for bandwidth bound kernels 2. `cuda_kernel`: print out generated cuda kernels 3. `launch_param`: print out launch config of generated kernels -4. `print_args`: print out input output tensors of executed codegen kernels +4. 
`kernel_args`: print out input/output/buffer tensors of all executed codegen kernels, note that for buffers, we indicate whether they are zero-initialized, which hints on an extra kernel to fill the tensor before codegen kernels. ### FAQs diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index 175e6a62e4af5..2907652b7e6db 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -458,7 +458,6 @@ TensorView* unaryOp( } NVFUSER_DEFINE_UNARY_OP(set, Set) -NVFUSER_DEFINE_UNARY_OP(randlike, RandLike) NVFUSER_DEFINE_UNARY_OP(ceil, Ceil) NVFUSER_DEFINE_UNARY_OP(floor, Floor) NVFUSER_DEFINE_UNARY_OP(frac, Frac) @@ -469,6 +468,30 @@ NVFUSER_DEFINE_UNARY_OP(silu, Silu) NVFUSER_DEFINE_UNARY_OP(trunc, Trunc) #undef NVFUSER_DEFINE_UNARY_OP +Val* randlike(Val* v) { + TORCH_CHECK( + isFloatingPointType(v->dtype()), + "input must have floating point type, but got ", + v->dtype()); + auto rand_vals = unaryOp(UnaryOpType::RandLike, v); + return where( + eq(rand_vals, IrBuilder::create(1.0)), + IrBuilder::create(0.0), + rand_vals); +} + +TensorView* randlike(TensorView* v) { + TORCH_CHECK( + isFloatingPointType(v->dtype()), + "input must have floating point type, but got ", + v->dtype()); + auto rand_vals = unaryOp(UnaryOpType::RandLike, v); + return where( + eq(rand_vals, IrBuilder::create(1.0)), + IrBuilder::create(0.0), + rand_vals); +} + Val* bitwise_not(Val* v) { TORCH_CHECK( isIntegralType(v->dtype()) || v->dtype() == DataType::Bool, diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 5d5f53a260421..9b446de9d636a 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -114,6 +114,38 @@ std::string genCall( return ss.str(); } +//! A utility class to check if an expression of a particular type exists +class ExprFinder : kir::ConstIrVisitor { + public: + //! True if expr or any of its nested expressions is included in + //! 
expr_types + static bool exists( + const Expr* expr, + const std::unordered_set& expr_types) { + ExprFinder finder(expr_types); + finder.handle(std::vector{expr}); + return finder.is_found_; + } + + private: + ExprFinder(const std::unordered_set& expr_types) + : expr_types_(expr_types) {} + + using kir::ConstIrVisitor::handle; + + void handle(const Expr* expr) final { + if (expr_types_.find(expr->etype()) != expr_types_.end()) { + is_found_ = true; + return; + } + kir::ConstIrVisitor::handle(expr); + } + + private: + const std::unordered_set& expr_types_; + bool is_found_ = false; +}; + class CudaKernelGenerator : private OptOutConstDispatch { static constexpr const char* kTab = " "; @@ -132,7 +164,15 @@ class CudaKernelGenerator : private OptOutConstDispatch { } private: - explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) {} + explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) { + initStringStreamFormat(code_); + } + + void initStringStreamFormat(std::stringstream& ss) { + const int digits = std::numeric_limits::max_digits10; + ss.imbue(std::locale("C")); + ss << std::scientific << std::setprecision(digits); + } // Generates the kernel function declaration void genDeclaration(const std::string& kernel_name) { @@ -320,6 +360,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { std::string gen(const Statement* stmt) { std::stringstream tmp_code; + initStringStreamFormat(tmp_code); std::swap(tmp_code, code_); OptOutConstDispatch::handle(stmt); std::swap(tmp_code, code_); @@ -330,6 +371,8 @@ class CudaKernelGenerator : private OptOutConstDispatch { std::stringstream name; if (val->isA()) { name << "T"; + } else if (val->isA()) { + name << "ip"; } else { name << typePrefix(val->dtype()); } @@ -381,9 +424,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { } else if (std::isnan(val)) { code_ << "NAN"; } else { - const int digits = - std::numeric_limits::max_digits10; - code_ << std::setprecision(digits) << val; + code_ << val; } } else { code_ << varName(d); @@ -391,6 +432,14 @@ class CudaKernelGenerator : private OptOutConstDispatch { } void handle(const Int* i) final { + // Check the replacement map first. If there's an entry for i, use + // the corresponding replacement. + auto replace_it = index_replacement_map_.find(i); + if (replace_it != index_replacement_map_.end()) { + code_ << replace_it->second; + return; + } + const auto def = i->definition(); const bool has_alloc = alloc_map_.find(i) != alloc_map_.end(); if (def != nullptr && !has_alloc) { @@ -408,9 +457,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { if (def != nullptr && !has_alloc) { code_ << "(" << gen(def) << ")"; } else if (c->isConst()) { - const int digits = std::numeric_limits::max_digits10; - code_ << "std::complex" << std::setprecision(digits) - << *c->value(); + code_ << "std::complex" << *c->value(); } else { code_ << varName(c); } @@ -1529,7 +1576,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { } TORCH_INTERNAL_ASSERT( - grouped_grop->numReductions() == 2, + grouped_grop->numExprs() == 2, "Only grouping of 2 reductions is supported. 
", grouped_grop->toString()); @@ -1548,7 +1595,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { ArgumentBuilder func_args(block_nest_level_ + 1, kTab); // Append arguments for each reduction - for (const auto i : c10::irange(grouped_grop->numReductions())) { + for (const auto i : c10::irange(grouped_grop->numExprs())) { TORCH_INTERNAL_ASSERT( grouped_grop->reduction_buffers().at(i)->buffer()->isA()); const auto work_buffer = @@ -1590,17 +1637,106 @@ class CudaKernelGenerator : private OptOutConstDispatch { indent() << kTab << func_args << ");\n"; } + // Enumerates all combinations of index values of grouped + // loops. Each combination is a vector of loop index values. The + // length of the vector is the number of grouped loops. + // + // Example 1: only one domain of extent 2 is grouped: {{0}, {1}}. + // Example 2: two domains of extents 2 and 3 are grouped: {{0, 0}, + // {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}} + std::vector> getGroupedLoopIndexConcreteIntSets() { + std::vector> index_combinationsatoins; + + // Initialize with an empty vector + index_combinationsatoins.push_back(std::vector()); + + // Incrementally build a combinatorial set + for (const auto loop : grouped_loops_) { + const auto iter_count = loop->stop()->evaluateInt(); + std::vector> new_combinations; + // Append integers from 0 to iter_count to all the vectors built + // so far + for (const auto& index_vec : index_combinationsatoins) { + for (int64_t i = 0; i < iter_count; ++i) { + auto index_vec_appended = index_vec; + index_vec_appended.push_back(i); + new_combinations.push_back(index_vec_appended); + } + } + index_combinationsatoins = std::move(new_combinations); + } + + return index_combinationsatoins; + } + + //! Returns all combinations of maps from index Vals of grouped loops to their + //! conrete integers. + std::vector> + getLoopIndexReplacementMaps() { + std::vector> maps; + + if (grouped_loops_.empty()) { + std::unordered_map empty_map; + return {empty_map}; + } + + // Vector of indices of grouped loops + std::vector loop_indices; + std::transform( + grouped_loops_.begin(), + grouped_loops_.end(), + std::back_inserter(loop_indices), + [](const kir::ForLoop* loop) { return loop->index()->as(); }); + + // All combinations of loop index integer values + const auto index_val_sets = getGroupedLoopIndexConcreteIntSets(); + + // Create maps from loop index Vals to integers + for (const auto& index_values : index_val_sets) { + TORCH_INTERNAL_ASSERT(loop_indices.size() == index_values.size()); + std::unordered_map index_val_map; + for (const auto i : c10::irange(loop_indices.size())) { + auto loop_index = loop_indices.at(i); + auto index_val = index_values.at(i); + index_val_map.emplace(loop_index, index_val); + } + maps.emplace_back(std::move(index_val_map)); + } + + return maps; + } + void generateGroupedGridAllreduce( const kir::GroupedGridReduction* grouped_grop) { TORCH_INTERNAL_ASSERT(grouped_grop->isAllreduce()); - constexpr int max_num_reductions = 8; + // There are two dimensions of grouping: horizontal grouping and + // iteration grouping. The total number of individual reductions + // is the number of horizontal reductions * the extent of grouped + // iterations. All of them are packed into a single grid reduction + // call. The number of reductions is limited, and currently it is + // simply an error if exceeded. This could be avoided by + // decomposing grouped_grop into smaller groups within the + // limit. TODO: Support a larger number of reductions. 
+ + // First, enumerate all combinations of loop index values of + // grouped IterDomains. If only a single domain is grouped, this + // is simply just a 1D vector of integer from 0 to extent-1. If + // two domains are grouped, combinations of two integer vectors + // are returned. These loop index value vectors are returned as a + // map from loop index Vals to concrete int values. + const auto index_replacement_maps = getLoopIndexReplacementMaps(); + const auto num_grouped_iterations = index_replacement_maps.size(); + + // This is also checked at the lowering validaiton time, so it + // isn't strictly necessary. TORCH_INTERNAL_ASSERT( - grouped_grop->numReductions() <= max_num_reductions, + num_grouped_iterations * grouped_grop->numExprs() <= + kMaxNumGroupedReductions, "Too many grouped reductions: ", grouped_grop->toString(), ". Up to ", - max_num_reductions, + kMaxNumGroupedReductions, " reductions are allowed."); ArgumentBuilder types; @@ -1614,44 +1750,65 @@ class CudaKernelGenerator : private OptOutConstDispatch { ArgumentBuilder read_preds; ArgumentBuilder write_preds; - for (const auto i : c10::irange(grouped_grop->numReductions())) { - const auto data_type = grouped_grop->outputs().at(i)->dtype(); - TORCH_INTERNAL_ASSERT( - grouped_grop->reduction_buffers().at(i)->buffer()->isA()); - - types.arg(data_type); - - // out - outputs.arg(gen(grouped_grop->outputs().at(i))); - - // inp - inputs.arg(gen(grouped_grop->inputs().at(i))); + for (const auto expr_index : c10::irange(grouped_grop->numExprs())) { + const auto data_type = grouped_grop->outputs().at(expr_index)->dtype(); + TORCH_INTERNAL_ASSERT(grouped_grop->reduction_buffers() + .at(expr_index) + ->buffer() + ->isA()); - // global_work_buffer - const auto work_buffer = - grouped_grop->reduction_buffers().at(i)->buffer()->as(); - work_bufs.arg("&").append(varName(work_buffer)).append("[0]"); + for (const auto& group_index : + c10::irange(index_replacement_maps.size())) { + // Set the index replacement map with the concrete values of + // indices of grouped loops. + index_replacement_map_ = index_replacement_maps.at(group_index); - init_vals.arg(genInline(grouped_grop->initVal(i))); + types.arg(data_type); - reduction_ops.arg(genReductionOp( - grouped_grop->getReductionOpType(i), - grouped_grop->output(i)->dtype())); + // out + outputs.arg(gen(grouped_grop->outputs().at(expr_index))); + + // inp + inputs.arg(gen(grouped_grop->inputs().at(expr_index))); + + // global_work_buffer + const auto work_buffer = grouped_grop->reduction_buffers() + .at(expr_index) + ->buffer() + ->as(); + // Separate Work buffer is used for each reduction. + auto work_buffer_offset = group_index == 0 + ? "0" + : (genInline(grouped_grop->buffer_stride()) + " * " + + std::to_string(group_index)); + work_bufs.arg("&") + .append(varName(work_buffer)) + .append("[") + .append(work_buffer_offset) + .append("]"); + init_vals.arg(genInline(grouped_grop->initVal(expr_index))); + + reduction_ops.arg(genReductionOp( + grouped_grop->getReductionOpType(expr_index), + grouped_grop->output(expr_index)->dtype())); + + // read and write predicates + bool_types.arg("bool"); + // Same argument for all inputs. 
Different predicates would be + // used when grouping is done across iterations + TORCH_INTERNAL_ASSERT( + grouped_grop->predicate() != nullptr && + grouped_grop->predicate()->hasValue()); + const auto read_pred = genInline(grouped_grop->predicate()); + read_preds.arg(read_pred); + if (grouped_grop->writePredicate() != nullptr) { + TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue()); + write_preds.arg(genInline(grouped_grop->writePredicate())); + } else { + write_preds.arg(read_pred); + } - // read and write predicates - bool_types.arg("bool"); - // Same argument for all inputs. Different predicates would be - // used when grouping is done across iterations - TORCH_INTERNAL_ASSERT( - grouped_grop->predicate() != nullptr && - grouped_grop->predicate()->hasValue()); - const auto read_pred = genInline(grouped_grop->predicate()); - read_preds.arg(read_pred); - if (grouped_grop->writePredicate() != nullptr) { - TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue()); - write_preds.arg(genInline(grouped_grop->writePredicate())); - } else { - write_preds.arg(read_pred); + index_replacement_map_.clear(); } } @@ -1969,7 +2126,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { void handleTrivialLoop(const kir::ForLoop* loop) { if (loop->vectorize()) { - vectorize_scope_ = loop->vectorize(); + vectorize_scope_ = true; } handleScope(loop->body()); if (loop->vectorize()) { @@ -1978,7 +2135,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { } void handle(const GroupedReductionOp* grouped_rop) final { - for (const auto i : c10::irange(grouped_rop->numReductions())) { + for (const auto i : c10::irange(grouped_rop->numExprs())) { TORCH_INTERNAL_ASSERT(grouped_rop->output(i)->isA()); const auto output = grouped_rop->output(i)->as(); @@ -1991,7 +2148,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { TORCH_INTERNAL_ASSERT( !has_grid_reduce, - "GroupedReductionOp does not support block parallelization. GroupedGridReductionOp must be used. ", + "GroupedReductionOp does not support block parallelization. GroupedGridReduction must be used. ", grouped_rop->toString()); if (!has_block_reduce) { @@ -2017,12 +2174,32 @@ class CudaKernelGenerator : private OptOutConstDispatch { } } + //! True if loop is grouped. The IterDomain of the loop must have + //! ParallelType::Group, but it isn't sufficient as the loop may be + //! for an initialization expression, for which the loop shold not + //! be grouped. Make sure a GroupedGridReduction is found. + bool isGroupedLoop(const kir::ForLoop* loop) { + if (loop->iter_domain()->getParallelType() != ParallelType::Group) { + return false; + } + return ExprFinder::exists(loop, {ExprType::GroupedGridReduction}); + } + void handle(const kir::ForLoop* loop) final { if (loop->isTrivial()) { handleTrivialLoop(loop); return; } + // If a loop is grouped, no loop is created, but it isn't + // considered trivial as the loop trip count is not one. 
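When a grouped loop is handled below, no C++ for-loop is emitted; the loop body is generated once per index combination, with the symbolic loop index printed as a concrete constant through `index_replacement_map_` (see the updated `handle(const Int*)` earlier in this file). The following is a minimal sketch of that lookup-before-print pattern; `SymbolicIndex`, `IndexReplacementMap`, and `genIndex` are illustrative stand-ins for the actual kir classes and generator state.

```cpp
// Minimal sketch of index replacement during code generation: a symbolic
// loop index is looked up in a replacement map and printed as a concrete
// integer when an entry exists, otherwise as its variable name.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

struct SymbolicIndex {
  std::string name;
};

// Stand-in for the generator's index_replacement_map_.
using IndexReplacementMap = std::unordered_map<const SymbolicIndex*, int64_t>;

std::string genIndex(
    const SymbolicIndex* index,
    const IndexReplacementMap& replacement_map) {
  auto it = replacement_map.find(index);
  if (it != replacement_map.end()) {
    // Grouped-loop index: emit the concrete value of this combination.
    return std::to_string(it->second);
  }
  // Ordinary loop index: emit the variable name.
  return index->name;
}

int main() {
  SymbolicIndex i{"i17"};
  IndexReplacementMap no_replacement;
  IndexReplacementMap grouped{{&i, 2}};
  std::cout << genIndex(&i, no_replacement) << "\n"; // prints "i17"
  std::cout << genIndex(&i, grouped) << "\n";        // prints "2"
  return 0;
}
```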
+ if (isGroupedLoop(loop)) { + grouped_loops_.push_back(loop); + handleScope(loop->body()); + grouped_loops_.pop_back(); + return; + } + const auto gen_index = gen(loop->index()); const auto gen_start = genInline(loop->start()); const auto gen_stop = genInline(loop->stop()); @@ -2198,6 +2375,52 @@ class CudaKernelGenerator : private OptOutConstDispatch { indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n"; } + void handle(const kir::Swizzle2DInt* swizzle_2d) { + TORCH_INTERNAL_ASSERT(print_inline_); + TORCH_INTERNAL_ASSERT( + swizzle_2d->swizzleType() != Swizzle2DType::NoSwizzle, + "Swizzle type undefined."); + if (print_inline_) { + code_ << swizzle_2d->swizzleType() << "({" << gen(swizzle_2d->inX()) + << "," << gen(swizzle_2d->inY()) << "} , " + << "{" << gen(swizzle_2d->extentX()) << "," + << gen(swizzle_2d->extentY()) << "})"; + } + } + + void handle(const kir::IntPair* int_pair) { + const auto def = int_pair->definition(); + if (print_inline_) { + code_ << gen(def); + } else { + code_ << varName(int_pair); + } + } + + void handle(const kir::PairSelect* pair_select) { + if (print_inline_) { + code_ << gen(pair_select->in()); + } else { + indent() << gen(pair_select->out()) << " = " << gen(pair_select->in()); + } + + switch (pair_select->selection()) { + case kir::PairSelect::Selection::X: + code_ << ".x"; + break; + case kir::PairSelect::Selection::Y: + code_ << ".y"; + break; + default: + TORCH_INTERNAL_ASSERT(false, "unknown select") + break; + } + + if (!print_inline_) { + code_ << ";\n"; + } + } + private: std::stringstream code_; const kir::Kernel* kernel_; @@ -2207,10 +2430,13 @@ class CudaKernelGenerator : private OptOutConstDispatch { // Mark when we are inside of a vectorized for-loop bool vectorize_scope_ = false; - //! Keep track of Allocate node for Val. Used to determine if Val //! should be inlined. std::unordered_map alloc_map_; + //! Keep track of grouped loops + std::deque grouped_loops_; + //! Used to replace symbolic indices with concrete values + std::unordered_map index_replacement_map_; }; } // namespace diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index 5603d5e54d577..0c7df3354da43 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include @@ -41,292 +40,9 @@ std::deque> tvChains( return tv_chains; } -bool validateDomain(TensorView* tv, TensorDomain* new_td) { - auto first_mismatch = - BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td); - return first_mismatch >= (int)tv->getMaxProducerPosition() && - first_mismatch >= (int)tv->getComputeAtPosition(); -} - -// Return the max position in consumer that producer can be inlined to -// Cannot inline: -// Reduction dimensions in producer -// Block broadcast dimensions in producer -// Vectorized dimensions in producer or consumer -// Unrolled dimensions in producer or consumer -// Dimensions derived from root dimensions that exist in both but are -// unmappable -unsigned int getReplayablePosPasC( - TensorView* producer, - TensorView* consumer, - const std::unordered_set& unmappable_producer_dims, - ComputeAtMode mode) { - // Check if any consumer dimensions are marked as vectorize as producer can - // not be inlined to vectorized dimensions in consumer. 
- auto c_dom = consumer->domain()->domain(); - auto vector_dim_it = - std::find_if(c_dom.begin(), c_dom.end(), [&mode](IterDomain* id) { - return isParallelTypeVectorize(id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && - id->getParallelType() == ParallelType::Unroll); - }); - - // Limit max position based on vectorized dims in consumer. - auto max_consumer_pos = std::distance(c_dom.begin(), vector_dim_it); - - auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); - auto c2p_root_map = - PairwiseRootDomainMap(producer, consumer) - .mapConsumerToProducer(consumer->domain(), producer->domain()); - - auto replay_PasC = - BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_root_map); - - // Look for id's that map to a consumer id that's vectorized - auto c2p_replay_map = replay_PasC.getReplay(); - - for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; - consumer_pos--) { - auto map_it = c2p_replay_map.find(consumer->axis((int)consumer_pos - 1)); - if (map_it != c2p_replay_map.end()) { - auto p_id = map_it->second; - // If we find a consumer dim that maps to a producer dim that's - // vectorized or unrolled limit max compute at by it. - if (isParallelTypeVectorize(p_id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && - p_id->getParallelType() == ParallelType::Unroll)) { - max_consumer_pos = consumer_pos - 1; - } - } - } - - // Start at max position and work backwards, try to find a location where - // producer can be inlined. - for (size_t consumer_pos = max_consumer_pos; consumer_pos > 0; - consumer_pos--) { - // Grab all root dimensions of consumer as roots must be used to understand - // inlining potential. - auto consumer_root_dim_vals = - IterVisitor::getInputsTo({c_dom.begin(), c_dom.begin() + consumer_pos}); - // convert to iter domains - auto consumer_root_dim_ids = - ir_utils::filterByType(consumer_root_dim_vals); - // If any root dimensions cannot be mapped to producer we can't inline. 
If - // any root dimension - if (std::any_of( - consumer_root_dim_ids.begin(), - consumer_root_dim_ids.end(), - [&unmappable_producer_dims, &c2p_root_map](IterDomain* c_root_id) { - auto p_root_id_it = c2p_root_map.find(c_root_id); - if (p_root_id_it == c2p_root_map.end()) { - return false; - } - auto p_id = p_root_id_it->second; - return unmappable_producer_dims.find(p_id) != - unmappable_producer_dims.end(); - })) { - continue; - } - return consumer_pos; - } - - return 0; -} - -// Return the max position in producer that can be inlined to consumer -// Cannot inline: -// Reduction dimensions in producer -// Vectorized dimensions in producer or consumer -// Unrolled dimensions in producer or consumer -// Dimensions derived from root dimensions that exist in both but are -// unmappable -unsigned int getReplayablePosCasP( - TensorView* consumer, - TensorView* producer, - const std::unordered_set& unmappable_producer_dims, - ComputeAtMode mode) { - auto p_dom = producer->domain()->domain(); - auto first_reduction = - std::find_if(p_dom.begin(), p_dom.end(), [](IterDomain* id) { - return id->isReduction(); - }); - - auto first_vectorized_axis = - std::find_if(p_dom.begin(), first_reduction, [&mode](IterDomain* id) { - return isParallelTypeVectorize(id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && - id->getParallelType() == ParallelType::Unroll); - }); - - auto max_producer_pos = std::distance(p_dom.begin(), first_vectorized_axis); - - auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); - auto p2c_root_map = pairwise_root_map.mapProducerToConsumer( - producer->domain(), consumer->domain()); - - auto replay_CasP = - BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); - - // Look for id's that map to a consumer id that's vectorized - auto p2c_replay_map = replay_CasP.getReplay(); - - for (size_t producer_pos = max_producer_pos; producer_pos > 0; - producer_pos--) { - auto map_it = p2c_replay_map.find(producer->axis((int)producer_pos - 1)); - if (map_it != p2c_replay_map.end()) { - auto c_id = map_it->second; - // If we find a producer dim that maps to a consumer vectorized or - // unrolled dim, limit max compute at by it - if (isParallelTypeVectorize(c_id->getParallelType()) || - ((mode == ComputeAtMode::BestEffort || - mode == ComputeAtMode::MostInlined) && - c_id->getParallelType() == ParallelType::Unroll)) { - max_producer_pos = producer_pos - 1; - } - } - } - - for (size_t producer_pos = max_producer_pos; producer_pos > 0; - producer_pos--) { - auto all_vals = DependencyCheck::getAllValsBetween( - {producer->getMaybeRFactorDomain().begin(), - producer->getMaybeRFactorDomain().end()}, - {p_dom.begin(), p_dom.begin() + producer_pos}); - - // If any root dims could have mapped to consumer, but don't, then we can't - // compute at this point - if (std::any_of( - producer->getMaybeRFactorDomain().begin(), - producer->getMaybeRFactorDomain().end(), - [&unmappable_producer_dims, &all_vals](IterDomain* p_root_id) { - return std::find(all_vals.begin(), all_vals.end(), p_root_id) != - all_vals.end() && - unmappable_producer_dims.find(p_root_id) != - unmappable_producer_dims.end(); - })) { - continue; - } - - return producer_pos; - } - return 0; -} - -unsigned int getInnermostNonBroadcastIdFrom(TensorView* tv) { - unsigned int ret = tv->getComputeAtPosition(); - - // Still assuming we only have block broadcast for now. 
- // This part may change - while (ret > 0 && tv->axis((int)ret - 1)->isBroadcast()) { - ret--; - } - - return ret; -} - -// Try to find the aligned position on consumer's domain corresponding to the -// compute at position of producer domain. Used in computeAt pass only. No -// checking on actual producer-consumer relationship. -unsigned int getConsumerPosAlignedToProducerCA( - TensorView* consumer, - TensorView* producer) { - // Locate consumer's position that aligns with - // the producer's new compute at axis. We need broadcast axes forwarded so we - // need to replay PasC as CasP will not forward braodcast dims. For example - // if we have: - // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) - // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will - // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to - // NVFuserTest.FusionComplexBCast1_CUDA - - auto c2p_map = - BestEffortReplay::replayPasC( - producer, - consumer, - -1, - // Compute at root domain may not be valid here, as all - // producers don't have to be able to map into consumer at - // max producer position. Since computeAt should be valid - // and this mechanism is only intended to lower produce - // position of consumer, we can simply use the pairwise map. - PairwiseRootDomainMap(producer, consumer)) - .getReplay(); - - // Find the innermost position of consumer that has - // been mapped within the producer ca axis. - unsigned int consumer_pos = consumer->nDims(); - while (consumer_pos > 0) { - auto consumer_id = consumer->axis((int)consumer_pos - 1); - auto p_dom = producer->domain()->domain(); - if (std::any_of( - p_dom.begin(), - p_dom.begin() + producer->getComputeAtPosition(), - [&consumer_id, &c2p_map](IterDomain* p_id) { - auto c_id_it = c2p_map.find(consumer_id); - if (c_id_it != c2p_map.end()) { - return c_id_it->second == p_id; - } - return false; - })) { - break; - } - consumer_pos--; - } - - return consumer_pos; -} - -} // namespace - -void ComputeAt::runAt( - TensorView* producer, - TensorView* consumer, - unsigned int consumer_position, - ComputeAtMode mode) { - FUSER_PERF_SCOPE("ComputeAt::run"); - - // Make sure the correct fusion is setup between this and consumer. - TORCH_CHECK( - producer->fusion() == consumer->fusion(), - producer, - " and ", - consumer, - " are not in the same fusion."); - - // Make sure Fusion Guard is set appropriately - FusionGuard fg(producer->fusion()); - - TORCH_CHECK( - DependencyCheck::isDependencyOf(producer, consumer), - "Compute At expects ", - producer->name(), - " is a dependency of ", - consumer->name(), - ", however it is not."); - - // Run computeAt on our potentially modified producer(s) - ComputeAt ca(producer, consumer, consumer, consumer_position, mode); - ca.runPass(); -} - -void ComputeAt::runWith( +std::unordered_set getAllTVsBetween( TensorView* producer, - TensorView* consumer, - unsigned int producer_position, - ComputeAtMode mode) { - FUSER_PERF_SCOPE("ComputeAt::runWith"); - - // Make sure the correct fusion is setup between this and consumer. 
- TORCH_CHECK( - producer->fusion() == consumer->fusion(), - producer, - " and ", - consumer, - " are not in the same fusion."); - + TensorView* consumer) { TORCH_CHECK( DependencyCheck::isDependencyOf(producer, consumer), "Compute At expects ", @@ -334,252 +50,19 @@ void ComputeAt::runWith( " is a dependency of ", consumer->name(), ", however it is not."); - - // Make sure Fusion Guard is set appropriately - FusionGuard fg(producer->fusion()); - - ComputeAt ca(producer, consumer, producer, producer_position, mode); - ca.runPass(); + auto between_vals = + DependencyCheck::getAllValsBetween({producer}, {consumer}); + auto between_tvs = ir_utils::filterByType(between_vals); + std::unordered_set result( + between_tvs.begin(), between_tvs.end()); + result.erase(consumer); + return result; } -namespace { - -// Checks if producer and consumer are transformed consistently so that to -// satisfy the provided compute at position. This means no replay is actually -// necessary for the compute at requested. If consumer_pos then -// consumer_or_producer_pos is relative to the consumer and skipReplay returns -// the associated position in producer. -// -// If producer and consumer are not transformed consistently with provided -// postition, returns -1. -int skipReplay( - const TensorView* producer, - const TensorView* consumer, - int consumer_or_producer_pos, - bool consumer_pos = true) { - FUSER_PERF_SCOPE("transform_replay.cpp::skipReplay"); - - const auto c2p_root_map = - PairwiseRootDomainMap(producer, consumer) - .mapConsumerToProducer(consumer->domain(), producer->domain()); - - // IterDomains in consumer root also in producer root - std::unordered_set mapped_consumer_roots; - for (auto entry : c2p_root_map) { - mapped_consumer_roots.emplace(entry.first); - } - - const auto consumer_domain = consumer->domain()->domain(); - - auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween( - mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); - - std::unordered_set mapped_consumer_domain_ids( - mapped_consumer_domain_ids_vec.begin(), - mapped_consumer_domain_ids_vec.end()); - - const auto producer_domain = producer->domain()->domain(); - - auto it_consumer = consumer_domain.begin(); - auto it_producer = producer_domain.begin(); - - auto best_effort_PasC = BestEffortReplay::replayPasC( - producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)); - - auto c2p_map = best_effort_PasC.getReplay(); - - int mismatched_consumer_pos = 0; - int mismatched_producer_pos = 0; - while (it_consumer != consumer_domain.end()) { - auto consumer_id = *it_consumer; - if (!mapped_consumer_domain_ids.count(consumer_id)) { - ++it_consumer; - mismatched_consumer_pos++; - continue; - } - - auto c2p_it = c2p_map.find(consumer_id); - if (c2p_it == c2p_map.end()) { - break; - } - - if (it_producer == producer_domain.end()) { - break; - } - - auto producer_id = *it_producer; - - if (c2p_it->second == producer_id) { - ++mismatched_consumer_pos; - ++mismatched_producer_pos; - ++it_consumer; - ++it_producer; - if (consumer_pos) { - if (consumer_or_producer_pos == mismatched_consumer_pos) { - return mismatched_producer_pos; - } - } else { - if (consumer_or_producer_pos == mismatched_producer_pos) { - return mismatched_consumer_pos; - } - } - } else { - break; - } - } - return -1; -} - -} // namespace - -// Actually applies transformation -unsigned int ComputeAt::backwardComputeAt_impl( - TensorView* producer, - TensorView* consumer, - unsigned int consumer_compute_at_pos) { - 
FUSER_PERF_SCOPE("backwardComputeAt_impl"); - - auto max_consumer_compute_at_pos = - getReplayablePosPasC(producer, consumer, unmappable_dims_, mode_); - - if (mode_ == ComputeAtMode::BestEffort) { - consumer_compute_at_pos = - std::min(consumer_compute_at_pos, max_consumer_compute_at_pos); - } else if (mode_ == ComputeAtMode::MostInlined) { - consumer_compute_at_pos = max_consumer_compute_at_pos; - } else { - TORCH_INTERNAL_ASSERT( - consumer_compute_at_pos <= max_consumer_compute_at_pos, - "Invalid compute at position detected in compute at when trying to replay producer: ", - producer, - " as consumer: ", - consumer, - " tried to do this at position: ", - consumer_compute_at_pos, - " but max position that's allowed is ", - max_consumer_compute_at_pos); - } - - // Short cut if no replay is necessary - auto maybe_producer_pos = - skipReplay(producer, consumer, (int)consumer_compute_at_pos, true); - if (maybe_producer_pos >= 0) { - if (!producer->isFusionInput()) { - producer->setComputeAt((unsigned int)maybe_producer_pos); - } - consumer->setMaxProducer(consumer_compute_at_pos); - return (unsigned int)maybe_producer_pos; - } - - auto replay_producer_pair = TransformReplay::replayPasC( - producer, - consumer, - (int)consumer_compute_at_pos, - PairwiseRootDomainMap(producer, consumer)); - - if (replay_producer_pair.second == 0) { - return 0; - } - - if (replay_producer_pair.second >= producer->getComputeAtPosition()) { - const TensorDomain* current_domain = producer->domain(); - TensorDomain* new_domain = replay_producer_pair.first; - - TORCH_INTERNAL_ASSERT( - validateDomain(producer, new_domain), - "Tried to set the domain of ", - producer, - " to ", - new_domain, - " but that would invalidate previously compute at position or max producer position."); - - producer->setDomain(new_domain); - if (!producer->isFusionInput()) { - producer->setComputeAt(replay_producer_pair.second); - } - - consumer->setMaxProducer(consumer_compute_at_pos); - root_map_.setAlias(current_domain, new_domain); - } - - return replay_producer_pair.second; -} - -// Actually applies transformation, replay consumer based on producer, set -// compute at of producer, set pass position of consumer, return position -// relative to consumer -unsigned int ComputeAt::forwardComputeAt_impl( - TensorView* producer, - TensorView* consumer, - unsigned int producer_compute_at_pos) { - FUSER_PERF_SCOPE("forwardComputeAt_impl"); - - auto max_producer_compute_at_pos = - getReplayablePosCasP(consumer, producer, unmappable_dims_, mode_); - - if (mode_ == ComputeAtMode::BestEffort) { - producer_compute_at_pos = - std::min(producer_compute_at_pos, max_producer_compute_at_pos); - } else if (mode_ == ComputeAtMode::MostInlined) { - producer_compute_at_pos = max_producer_compute_at_pos; - } else { - TORCH_INTERNAL_ASSERT( - producer_compute_at_pos <= max_producer_compute_at_pos, - "Invalid compute at position detected in compute at when trying to replay consumer: ", - consumer, - " as producer: ", - producer, - " tried to do this at position: ", - producer_compute_at_pos, - " but max position that's allowed is ", - max_producer_compute_at_pos); - } - - // Short cut if no replay is necessary - auto maybe_consumer_pos = - skipReplay(producer, consumer, (int)producer_compute_at_pos, false); - if (maybe_consumer_pos > -1) { - if (!producer->isFusionInput()) { - producer->setComputeAt(producer_compute_at_pos); - } - consumer->setMaxProducer((unsigned int)maybe_consumer_pos); - return (unsigned int)maybe_consumer_pos; - } - - auto 
replay_consumer_pair = TransformReplay::replayCasP( - consumer, - producer, - (int)producer_compute_at_pos, - PairwiseRootDomainMap(producer, consumer)); - - if (producer_compute_at_pos > producer->getComputeAtPosition()) { - if (!producer->isFusionInput()) { - producer->setComputeAt((int)producer_compute_at_pos); - } - } - - if (replay_consumer_pair.second > consumer->getMaxProducerPosition()) { - const TensorDomain* current_domain = consumer->domain(); - TensorDomain* new_domain = replay_consumer_pair.first; - - TORCH_INTERNAL_ASSERT( - validateDomain(consumer, new_domain), - "Tried to set the domain of ", - consumer, - " to ", - new_domain, - " but that would invalidate previously compute at position or max producer position."); - - consumer->setDomain(new_domain); - consumer->setMaxProducer(replay_consumer_pair.second); - root_map_.setAlias(current_domain, new_domain); - } - - return replay_consumer_pair.second; -} - -void ComputeAt::setCommonConsumer() { +TensorView* getCommonConsumer(TensorView* producer, TensorView* consumer) { FUSER_PERF_SCOPE("ComputeAt::setCommonConsumer"); + auto producer_use_chains_ = + tvChains(DependencyCheck::getAllUseChains(producer)); // Convert the first chain to a set. std::set common_consumers( @@ -594,336 +77,170 @@ void ComputeAt::setCommonConsumer() { } auto all_chains = - tvChains(DependencyCheck::getAllDependencyChains(producer_, consumer_)); + tvChains(DependencyCheck::getAllDependencyChains(producer, consumer)); // Right now we only support compute at if at some point in the graph consumer // is dependent on producer. TORCH_CHECK( !all_chains.empty(), "Compute At expects ", - producer_->name(), + producer->name(), " is a dependency of ", - consumer_->name(), + consumer->name(), ", however it is not."); // Remove all TVs from producer to consumer as common consumer must be at or // after consumer for (const auto& tv_chain : all_chains) { for (auto tv : tv_chain) { - if (tv != consumer_) + if (tv != consumer) common_consumers.erase(tv); } } // If there is a common consumer, grab the first one at or after consumer - common_consumer_ = nullptr; + TensorView* common_consumer = nullptr; if (!common_consumers.empty()) { for (auto tv : producer_use_chains_.front()) { if (common_consumers.find(tv) != common_consumers.end()) { - common_consumer_ = tv; + common_consumer = tv; break; } } TORCH_INTERNAL_ASSERT( - common_consumer_ != nullptr, + common_consumer != nullptr, "Hit a logical inconsistency in the computeAt pass."); } + return common_consumer; } -// Similar to backward traversal in traverseAllKnown but we should only apply -// computeAt if it will increase computeAt positions. -void ComputeAt::traverseBackward() { - FUSER_PERF_SCOPE("ComputeAt::traverseBackward"); - if (reference_ == producer_) { - // Forward compute at don't need to run backward traversal - producer_position_ = reference_position_; - return; - } - - // propagate *backward* through all *producer* use_chains or from *producer* - // to common_consumer if common_consumer exists. Only apply transform if - // increases computeAt position. 
- auto chains = - tvChains(DependencyCheck::getAllDependencyChains(producer_, consumer_)); - - for (auto tv_chain : chains) { - TensorView* running_producer = tv_chain.back(); - TensorView* running_consumer = nullptr; - unsigned int running_consumer_pos = reference_position_; - tv_chain.pop_back(); - - TORCH_INTERNAL_ASSERT(running_producer == consumer_); - - while (!tv_chain.empty()) { - running_consumer = running_producer; - running_producer = tv_chain.back(); - tv_chain.pop_back(); - running_consumer_pos = backwardComputeAt_impl( - running_producer, running_consumer, running_consumer_pos); - } - - TORCH_INTERNAL_ASSERT( - running_producer == producer_, - "Compute at backward traversal ended up on something other than the producer."); - producer_position_ = running_consumer_pos; - } -} - -void ComputeAt::traverseForward() { - FUSER_PERF_SCOPE("ComputeAt::traverseForward"); - - // propagate forward through all *producer* use_chains or from *producer* to - // common_consumer if common_consumer exists. - auto chains = producer_use_chains_; - if (common_consumer_ != nullptr) { - chains = tvChains( - DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); - } - - // propagate forward through all chains - for (auto tv_dep_chain : chains) { - TensorView* running_producer = nullptr; - TensorView* running_consumer = tv_dep_chain.front(); - tv_dep_chain.pop_front(); - unsigned int running_producer_pos = producer_position_; - - TORCH_INTERNAL_ASSERT(running_consumer == producer_); - - while (!tv_dep_chain.empty()) { - running_producer = running_consumer; - running_consumer = tv_dep_chain.front(); - tv_dep_chain.pop_front(); - running_producer_pos = forwardComputeAt_impl( - running_producer, running_consumer, running_producer_pos); +void pullInSiblings(std::unordered_set& s) { + for (auto tv : s) { + for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) { + if (sibling_tv == tv) { + continue; + } + s.emplace(sibling_tv); } } } -void ComputeAt::resetMaxProducerPos(TensorView* consumer_tv) { - if (consumer_tv->definition() == nullptr) { - consumer_tv->setMaxProducer(0, true); - } - - unsigned int new_consummer_pa_pos = 0; - - // Re-compute the max producer position as one or more - // of the producers of this consumer have updated their - // compute at position. - for (auto inp : ir_utils::producerTvsOf(consumer_tv)) { - if (!inp->isFusionInput()) { - // Locate consumer's position that aligns with - // the producer's new compute at axis. - unsigned int inp_ca_pos_to_consumer = - getConsumerPosAlignedToProducerCA(consumer_tv, inp); - - // Populate the max consumer position required by - // producer compute at. - new_consummer_pa_pos = - std::max(new_consummer_pa_pos, inp_ca_pos_to_consumer); - } - } - - consumer_tv->setMaxProducer(new_consummer_pa_pos, true); +// I am just trying to get the same set of tensors being transformed matching +// the previous behavior of ComputeAt. The algorithm to compute this set is +// horrible, but I don't care because I will eventually completely remove +// ComputeAt, and this algorihtm is not worse than the pervious ComputeAt. 
:) +std::unordered_set getPropagationSubgraph( + TensorView* producer, + TensorView* consumer) { + TORCH_CHECK( + DependencyCheck::isDependencyOf(producer, consumer), + "Compute At expects ", + producer->name(), + " is a dependency of ", + consumer->name(), + ", however it is not."); + TensorView* common_consumer = getCommonConsumer(producer, consumer); + if (common_consumer != nullptr) { + auto result = getAllTVsBetween(producer, common_consumer); + pullInSiblings(result); + return result; + } + auto result_vals = DependencyCheck::getAllDependentVals({producer}); + result_vals.emplace(producer); + auto result_tvs = ir_utils::filterByType(result_vals); + std::unordered_set result; + std::copy_if( + result_tvs.begin(), + result_tvs.end(), + std::inserter(result, result.begin()), + [](TensorView* tv) { return !tv->uses().empty(); }); + pullInSiblings(result); + return result; } -void ComputeAt::hoistInnermostBroadcast() { - auto fusion = producer_->fusion(); +} // namespace - std::unordered_set consumers_to_update; +void ComputeAt::runAt( + TensorView* producer, + TensorView* consumer, + int64_t consumer_position, + ComputeAtMode mode) { + FUSER_PERF_SCOPE("ComputeAt::runAt"); - auto all_vals = fusion->usedMathVals(); - auto all_tvs = ir_utils::filterByType(all_vals); + // Make sure the correct fusion is setup between this and consumer. + TORCH_CHECK( + producer->fusion() == consumer->fusion(), + producer, + " and ", + consumer, + " are not in the same fusion."); - for (auto running_producer : all_tvs) { - if (!running_producer->isFusionInput()) { - auto producer_ca_pos = running_producer->getComputeAtPosition(); - // Find the innermost iterdomain that is not a broadcast - auto new_ca_pos = getInnermostNonBroadcastIdFrom(running_producer); - // Update the compute at pos of this producer if the original - // compute at is within inner most broadcast axes - if (new_ca_pos < producer_ca_pos) { - running_producer->setComputeAt(new_ca_pos, true); - } - // Mark all consumers of this producer for later produce - // position update. - // This is safe with segmented fusion. TV uses will reset - // when FusionSegmentGuard try to change the IO. - auto tv_consumers = ir_utils::consumerTvsOf(running_producer); - consumers_to_update.insert(tv_consumers.begin(), tv_consumers.end()); - } + if (mode == ComputeAtMode::MostInlined) { + consumer_position = -1; } -} -void ComputeAt::updateSiblings() { - // Track which consumers may have a wrong produce at position to update - // later - auto updateSiblingsOfTv = [&](TensorView* tv) { - if (tv->definition() == nullptr) { - return; - } + FusionGuard fg(producer->fusion()); - std::unordered_set consumers_to_update; - - if (tv->definition()->outputs().size() > 1) { - auto outs = tv->definition()->outputs(); - auto out_tvs = ir_utils::filterByType(outs); - for (auto sibling_tv : out_tvs) { - if (sibling_tv == tv) { - continue; - } - - std::unordered_map tv_to_sibling_map; - TORCH_INTERNAL_ASSERT( - tv->getRootDomain().size() == sibling_tv->getRootDomain().size(), - "Error replaying multiple output expressions in computeAt."); - - // Propagate any root parallelization as fullSelfReplay expects it. 
- for (const auto i : c10::irange(sibling_tv->getRootDomain().size())) { - auto id = tv->getRootDomain()[i]; - auto sibling_id = sibling_tv->getRootDomain()[i]; - if (id->getParallelType() != ParallelType::Serial && - sibling_id->getParallelType() == ParallelType::Serial) { - sibling_id->parallelize(id->getParallelType()); - } else if ( - id->getParallelType() == ParallelType::Serial && - sibling_id->getParallelType() != ParallelType::Serial) { - id->parallelize(sibling_id->getParallelType()); - } - } - auto sibling_domain = - TransformReplay::fullSelfReplay(sibling_tv->domain(), tv->domain()); - validateDomain(sibling_tv, sibling_domain); - sibling_tv->setDomain(sibling_domain); - sibling_tv->setComputeAt(tv->getComputeAtPosition()); - sibling_tv->setMaxProducer(tv->getMaxProducerPosition()); - auto consumer_tvs = ir_utils::consumerTvsOf(sibling_tv); - consumers_to_update.insert(consumer_tvs.begin(), consumer_tvs.end()); - } - } + auto selected = getPropagationSubgraph(producer, consumer); + InlinePropagatorSelector selector(selected); - // Update sibling consumer tv's max producer position - for (auto consumer : consumers_to_update) { - this->resetMaxProducerPos(consumer); - } - }; + InlinePropagator inline_propagator( + consumer, consumer_position, mode, selector.selected()); + MaxProducerPosUpdater updater; - // Find all tensor views that may have been modified - auto chains = producer_use_chains_; - if (common_consumer_ != nullptr) { - chains = tvChains( - DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); - } + MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector); - std::unordered_set participating_tvs; - for (auto chain : chains) { - participating_tvs.insert(chain.begin(), chain.end()); + if (mode == ComputeAtMode::MostInlined) { + MostInlinedTransformPropagator propagator; + path.traverse(&propagator); + } else { + TransformPropagator propagator(consumer, consumer_position); + path.traverse(&propagator); } - for (auto tv : participating_tvs) { - updateSiblingsOfTv(tv); - } + path.traverse(&inline_propagator); + path.traverse(&updater); } -void ComputeAt::runPass() { - FUSER_PERF_SCOPE("ComputeAt::runPass"); - - // Traverse backward through all dep chains from producer to consumer - traverseBackward(); - - // Start at producer and traverse forward through all chains - traverseForward(); +void ComputeAt::runWith( + TensorView* producer, + TensorView* consumer, + int64_t producer_position, + ComputeAtMode mode) { + FUSER_PERF_SCOPE("ComputeAt::runWith"); - // Back off on inlining the inner broadcast axes - hoistInnermostBroadcast(); + // Make sure the correct fusion is setup between this and consumer. + TORCH_CHECK( + producer->fusion() == consumer->fusion(), + producer, + " and ", + consumer, + " are not in the same fusion."); - // Update siblings of multi output expressions - updateSiblings(); + if (mode == ComputeAtMode::MostInlined) { + producer_position = -1; + } - // Update the compute at position of all consumers, this used to be done - // during the compute at pass itself, but its cleaner to do this as a cleanup - // pass similar to hoistInnermostBroadcast and updateSiblings. 
- std::unordered_set all_consumers; + FusionGuard fg(producer->fusion()); - // Find all tensor views that may have been modified - auto chains = producer_use_chains_; - if (common_consumer_ != nullptr) { - chains = tvChains( - DependencyCheck::getAllDependencyChains(producer_, common_consumer_)); - } + auto selected = getPropagationSubgraph(producer, consumer); + InlinePropagatorSelector selector(selected); - for (const auto& chain : chains) { - for (auto tv : chain) { - all_consumers.emplace(tv); - } - } + InlinePropagator inline_propagator( + producer, producer_position, mode, selector.selected()); + MaxProducerPosUpdater updater; - // Reset max producer position of all tensor views. - for (auto tv : all_consumers) { - resetMaxProducerPos(tv); - } -} + MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector); -void ComputeAt::buildUnmappableDims() { - auto all_tvs = ir_utils::allTvs(producer_->fusion()); - for (auto tv : all_tvs) { - auto consumers = ir_utils::consumerTvsOf(tv); - for (auto consumer : consumers) { - // Grab dimensions in producer and consumer that are mappable to eachother - // based on the computeAtRootDomainMap. This will tell us which dimensions - // can be inlined based on avoiding trying to inline non-trivial - // reduction structures. - auto mappable_roots = - root_map_.getMappableDims(tv->domain(), consumer->domain()); - for (auto tv_root_id : tv->getMaybeRFactorDomain()) { - if (mappable_roots.find(tv_root_id) == mappable_roots.end() && - !tv_root_id->isTrivialReduction()) { - unmappable_dims_.emplace(tv_root_id); - } - } - } + if (mode == ComputeAtMode::MostInlined) { + MostInlinedTransformPropagator propagator; + path.traverse(&propagator); + } else { + TransformPropagator propagator(producer, producer_position); + path.traverse(&propagator); } -} - -ComputeAt::ComputeAt( - TensorView* _producer, - TensorView* _consumer, - TensorView* _reference, - unsigned int _reference_position, - ComputeAtMode _mode) - : producer_(_producer), - consumer_(_consumer), - reference_(_reference), - reference_position_(_reference_position), - mode_(_mode) { - TORCH_INTERNAL_ASSERT( - reference_ == producer_ || reference_ == consumer_, - "For compute at reference must be producer or consumer, it's neither.", - " reference: ", - reference_, - " consumer: ", - consumer_, - " producer: ", - producer_); - TORCH_INTERNAL_ASSERT( - reference_position_ <= reference_->nDims(), - "Invalid computeAt axis, received ", - reference_position_, - " but should be > -", - reference_->nDims(), - " and <= ", - reference_->nDims(), - "."); - - producer_use_chains_ = tvChains(DependencyCheck::getAllUseChains(producer_)); - - // Look through all the use chains of producer. Check if there's a single - // consumer for all chains at or after the consumer specified in the computeAt - // call. 
- setCommonConsumer(); - - root_map_.build(); - - buildUnmappableDims(); + path.traverse(&inline_propagator); + path.traverse(&updater); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index 75fca5705ed9e..98100334d72b6 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -1,6 +1,8 @@ #pragma once +#include #include +#include #include #include @@ -18,14 +20,14 @@ namespace cuda { class TensorDomain; class TensorView; -class ComputeAt { +struct ComputeAt { public: // Runs the compute at pass making producer look like consumer, computing // producer relative to consumer static void runAt( TensorView* producer, TensorView* consumer, - unsigned int consumer_position, + int64_t consumer_position, ComputeAtMode mode = ComputeAtMode::Standard); // Runs the compute with pass making consumer look like producer, computing @@ -33,95 +35,8 @@ class ComputeAt { static void runWith( TensorView* producer, TensorView* consumer, - unsigned int producer_position, + int64_t producer_position, ComputeAtMode mode = ComputeAtMode::Standard); - - ComputeAt() = delete; - ComputeAt(ComputeAt&) = delete; - ComputeAt& operator=(const ComputeAt& other) = delete; - - private: - TensorView* producer_; - TensorView* consumer_; - TensorView* reference_; - unsigned int reference_position_; - ComputeAtMode mode_ = ComputeAtMode::Standard; - - unsigned int producer_position_ = 0; - ComputeAtRootDomainMap root_map_; - - // Runs replayPasC and sets producer computeAt settings. Returns - // producer_compute_at_pos. - unsigned int backwardComputeAt_impl( - TensorView* producer, - TensorView* consumer, - unsigned int consumer_compute_at_pos); - - // Runs replayCasP and sets producer computeAt settings. Returns - // consumer_compute_at_pos. - unsigned int forwardComputeAt_impl( - TensorView* producer, - TensorView* consumer, - unsigned int producer_compute_at_pos); - - // Look through all the use chains of producer. Check if there's a single - // consumer for all chains at or after the consumer specified in the computeAt - // call. - void setCommonConsumer(); - - // Iterate through all TVs and collect the dimensions of each TV that don't - // map to all its consumer TVs. - void buildUnmappableDims(); - - // Propagate backward from consumer to producer, check if it increase - // computeAt position on tensors, if so take it! - void traverseBackward(); - - // Traverse from producer to common_consumer if it exists or through all uses - // of producer - void traverseForward(); - - // Looks at producer tensor views of consumer_tv, recomputes its max - // producer position, and sets max producer position. This function can - // only potentially lower the max producer position of consumer_tv. - void resetMaxProducerPos(TensorView* consumer_tv); - - // Undo the inlining of block broadcast at the innermost positions - // to avoid generating repeated block broadcasts - void hoistInnermostBroadcast(); - - // Update multi-output expressions. If one output is modified, all outputs - // should be modified as well. Propagate transformations, compute at, and - // produce at from tv to siblings. Run as final pass as it will invalidate the - // computeAt map originally computed. - void updateSiblings(); - - // Compute at pass requires tracking "maxProducerPosition" even if set simply - // from input tensor views. 
However, when lowering, we need a valid produce at - // position of all tensors, so inputs should never actually set their - // consumers maxProduceAt position. - void updateInputProduceAts(); - - // Run the computeAt pass - void runPass(); - - // Common consumer if it exists - TensorView* common_consumer_ = nullptr; - - // Producer use chains set in, used in a few spots. - std::deque> producer_use_chains_; - - // Root domains in producer that's unmappable to any of its consumers - std::unordered_set unmappable_dims_; - - ComputeAt( - TensorView* _producer, - TensorView* _consumer, - TensorView* _reference, - unsigned int _reference_position, - ComputeAtMode _mode); - - ~ComputeAt() = default; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index 21993f98d585e..d2d4cad6d4fe8 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -33,6 +33,28 @@ IterDomainGraph::IterDomainGraph(Fusion* fusion) { build(fusion); } +//! Map corresponding inputs and outputs of swizzle op together +//! on the given disjoint set, if the given id is an output +//! of a swizzle operator. +//! +//! The current usage of swizzle operator is local to each tensor +//! itself, so they should not affect exact or permissive mapping +//! between iterdomains on different tensor domains. +//! TODO: +//! Exact mapping based index hoisting of swizzled iterdomains +//! is disabled currently and will be re-enabled in the next +//! few build out steps. +void mapMaybeSwizzleOp( + DisjointSets& disjoint_sets, + IterDomain* id) { + if (auto swizzle_2d = dynamic_cast(id->definition())) { + // Map each input to its corresponding output on the given + // disjoint set. + disjoint_sets.mapEntries(swizzle_2d->inX(), swizzle_2d->outX()); + disjoint_sets.mapEntries(swizzle_2d->inY(), swizzle_2d->outY()); + } +} + void IterDomainGraph::build(Fusion* fusion) { // Initialize a node for every iteration domain for (auto tv : ir_utils::allTvs(fusion)) { @@ -178,6 +200,12 @@ void IterDomainGraph::build(Fusion* fusion) { exact_nodes_.mapEntries(c_id, p_id); consumers_.at(p_id).pushBack(c_id); producers_.at(c_id).pushBack(p_id); + + // Add the swizzle inputs to the same + // disjoint set as well if either c_id + // or p_id is swizzle output. + mapMaybeSwizzleOp(exact_nodes_, p_id); + mapMaybeSwizzleOp(exact_nodes_, c_id); } for (auto entry : permissive_c2p_map) { @@ -189,6 +217,12 @@ void IterDomainGraph::build(Fusion* fusion) { permissive_nodes_.mapEntries(c_id, p_id); consumers_.at(p_id).pushBack(c_id); producers_.at(c_id).pushBack(p_id); + + // Add the swizzle inputs to the same + // disjoint set as well if either c_id + // or p_id is swizzle output. + mapMaybeSwizzleOp(permissive_nodes_, p_id); + mapMaybeSwizzleOp(permissive_nodes_, c_id); } // Make sure we always get root mapping for the permissive map. Because diff --git a/torch/csrc/jit/codegen/cuda/contiguity.h b/torch/csrc/jit/codegen/cuda/contiguity.h index 24f0ffa6c7e56..7293901310eb6 100644 --- a/torch/csrc/jit/codegen/cuda/contiguity.h +++ b/torch/csrc/jit/codegen/cuda/contiguity.h @@ -89,6 +89,13 @@ class ContigIDs : public OptInDispatch { void handle(Merge* merge) override; + // TODO: + // Currently not propagating any contiguity information + // as contiguity is generally not preserved after swizzles. + // But in follow ups we could gradually add back a few special + // cases, depending on specific swizzle type and axes. 
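+  // As a rough illustration: a 32x32 shared-memory tile whose two leaf
+  // domains are swizzled can generally no longer be indexed as one merged
+  // contiguous 1024-wide id, since the swizzle permutes elements within the
+  // tile. Hence the conservative no-op handler below.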
+ void handle(Swizzle2D* swizzle) override {} + IterDomain* getCAIndexConcreteId(IterDomain* id) const; //! True if an ID is indexable. diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index eb18da1f909c3..411fd3759df41 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -83,6 +83,9 @@ void Val::dispatch(T handler, Val* val) { case ValType::TensorIndex: ptr(handler)->handle(val->as()); return; + case ValType::IntPair: + ptr(handler)->handle(val->as()); + return; default: break; } @@ -126,6 +129,9 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::Merge: ptr(handler)->handle(expr->as()); return; + case ExprType::Swizzle2D: + ptr(handler)->handle(expr->as()); + return; case ExprType::TransposeOp: ptr(handler)->handle(expr->as()); return; @@ -184,6 +190,12 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::AllocateFusedReduction: ptr(handler)->handle(expr->as()); return; + case ExprType::Swizzle2DInt: + ptr(handler)->handle(expr->as()); + return; + case ExprType::PairSelect: + ptr(handler)->handle(expr->as()); + return; default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } @@ -242,6 +254,9 @@ void Val::constDispatch(T handler, const Val* val) { case ValType::TensorIndex: ptr(handler)->handle(val->as()); return; + case ValType::IntPair: + ptr(handler)->handle(val->as()); + return; default: break; } @@ -285,6 +300,9 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::Merge: ptr(handler)->handle(expr->as()); return; + case ExprType::Swizzle2D: + ptr(handler)->handle(expr->as()); + return; case ExprType::TransposeOp: ptr(handler)->handle(expr->as()); return; @@ -343,6 +361,12 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::AllocateFusedReduction: ptr(handler)->handle(expr->as()); return; + case ExprType::Swizzle2DInt: + ptr(handler)->handle(expr->as()); + return; + case ExprType::PairSelect: + ptr(handler)->handle(expr->as()); + return; default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } @@ -409,6 +433,9 @@ void Val::mutatorDispatch(T mutator, Val* val) { case ValType::TensorIndex: ptr(mutator)->mutate(val->as()); return; + case ValType::IntPair: + ptr(mutator)->mutate(val->as()); + return; default: break; } @@ -452,6 +479,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) { case ExprType::Merge: ptr(mutator)->mutate(expr->as()); return; + case ExprType::Swizzle2D: + ptr(mutator)->mutate(expr->as()); + return; case ExprType::TransposeOp: ptr(mutator)->mutate(expr->as()); return; @@ -510,6 +540,12 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) { case ExprType::AllocateFusedReduction: ptr(mutator)->mutate(expr->as()); return; + case ExprType::Swizzle2DInt: + ptr(mutator)->mutate(expr->as()); + return; + case ExprType::PairSelect: + ptr(mutator)->mutate(expr->as()); + return; default: TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!"); } @@ -648,6 +684,9 @@ void OptOutConstDispatch::handle(const kir::Predicate* stmt) { void OptOutConstDispatch::handle(const kir::TensorIndex* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const kir::IntPair* stmt) { + unhandled(stmt); +} // Exprs void OptOutConstDispatch::handle(const UnaryOp* stmt) { @@ -684,6 +723,9 @@ void OptOutConstDispatch::handle(const Split* stmt) { void OptOutConstDispatch::handle(const Merge* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const Swizzle2D* stmt) { + unhandled(stmt); 
+} void OptOutConstDispatch::handle(const TransposeOp* stmt) { unhandled(stmt); } @@ -742,6 +784,12 @@ void OptOutConstDispatch::handle(const kir::GridWelford* stmt) { void OptOutConstDispatch::handle(const kir::AllocateFusedReduction* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const kir::Swizzle2DInt* stmt) { + unhandled(stmt); +} +void OptOutConstDispatch::handle(const kir::PairSelect* stmt) { + unhandled(stmt); +} void OptOutDispatch::unhandled(Statement*) {} @@ -777,6 +825,9 @@ void OptOutDispatch::handle(kir::Predicate* stmt) { void OptOutDispatch::handle(kir::TensorIndex* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(kir::IntPair* stmt) { + unhandled(stmt); +} // Exprs void OptOutDispatch::handle(UnaryOp* stmt) { @@ -813,6 +864,9 @@ void OptOutDispatch::handle(Split* stmt) { void OptOutDispatch::handle(Merge* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(Swizzle2D* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(TransposeOp* stmt) { unhandled(stmt); } @@ -871,6 +925,12 @@ void OptOutDispatch::handle(kir::GridWelford* stmt) { void OptOutDispatch::handle(kir::AllocateFusedReduction* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(kir::Swizzle2DInt* stmt) { + unhandled(stmt); +} +void OptOutDispatch::handle(kir::PairSelect* stmt) { + unhandled(stmt); +} } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index d83f1b28bd106..f148fac430ca3 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -87,10 +87,12 @@ class ViewOp; // Exprs class Split; class Merge; +class Swizzle2D; namespace kir { class Predicate; class TensorIndex; +class IntPair; class Allocate; class BlockSync; @@ -105,6 +107,9 @@ class GridWelford; class AllocateFusedReduction; class InitMagicZero; class UpdateMagicZero; +class Swizzle2DInt; +class PairSelect; + } // namespace kir // By default, all IR nodes are handled in this dispatch, and will call an empty @@ -131,6 +136,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const kir::Predicate*); virtual void handle(const kir::TensorIndex*); + virtual void handle(const kir::IntPair*); // Exprs virtual void handle(const UnaryOp* stmt); @@ -145,6 +151,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const Split* stmt); virtual void handle(const Merge* stmt); + virtual void handle(const Swizzle2D* stmt); virtual void handle(const TransposeOp* stmt); virtual void handle(const ExpandOp* stmt); virtual void handle(const ShiftOp* stmt); @@ -165,6 +172,8 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const kir::GridBroadcast*); virtual void handle(const kir::GridWelford*); virtual void handle(const kir::AllocateFusedReduction*); + virtual void handle(const kir::Swizzle2DInt*); + virtual void handle(const kir::PairSelect*); }; class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { @@ -189,6 +198,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(kir::Predicate*); virtual void handle(kir::TensorIndex*); + virtual void handle(kir::IntPair*); // Exprs virtual void handle(UnaryOp* stmt); @@ -203,6 +213,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(Split* stmt); virtual void handle(Merge* stmt); + virtual void handle(Swizzle2D* stmt); virtual void handle(TransposeOp* stmt); 
virtual void handle(ExpandOp* stmt); virtual void handle(ShiftOp* stmt); @@ -223,6 +234,8 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(kir::GridBroadcast* stmt); virtual void handle(kir::GridWelford* stmt); virtual void handle(kir::AllocateFusedReduction* stmt); + virtual void handle(kir::Swizzle2DInt* stmt); + virtual void handle(kir::PairSelect* stmt); }; class TORCH_CUDA_CU_API OptInConstDispatch : public OptOutConstDispatch { @@ -288,6 +301,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(kir::Predicate*); virtual void mutate(kir::TensorIndex*); + virtual void mutate(kir::IntPair*); // Exprs virtual void mutate(UnaryOp*); @@ -302,6 +316,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(Split*); virtual void mutate(Merge*); + virtual void mutate(Swizzle2D*); virtual void mutate(TransposeOp*); virtual void mutate(ExpandOp*); virtual void mutate(ShiftOp*); @@ -322,6 +337,8 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(kir::GridBroadcast*); virtual void mutate(kir::GridWelford*); virtual void mutate(kir::AllocateFusedReduction*); + virtual void mutate(kir::Swizzle2DInt*); + virtual void mutate(kir::PairSelect*); protected: void removeExpr(IrContainer*, Expr*); diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 78e0e3e0cba07..cab5c7bf9ad2a 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -790,6 +790,7 @@ std::vector FusionExecutor::runFusion( at::TensorOptions() .dtype(executor_entry->buffer_types[i]) .device(options_.device))); + global_buffers.zero_init.push_back(true); } else { global_buffers.buffers.push_back(at::native::empty_cuda( executor_entry->buffer_sizes[i], @@ -797,6 +798,7 @@ std::vector FusionExecutor::runFusion( c10::nullopt, options_.device, c10::nullopt)); + global_buffers.zero_init.push_back(false); } } } @@ -984,9 +986,14 @@ std::vector FusionExecutor::runFusion( << " (strides = " << output.strides() << ")" << std::endl; } std::cout << "Reduction and semaphore buffers:" << std::endl; - for (const auto& buffer : global_buffers.buffers) { + TORCH_INTERNAL_ASSERT( + global_buffers.buffers.size() == global_buffers.zero_init.size(), + "global_buffer buffer & zero_init container should have identical sizes"); + for (const auto i : c10::irange(global_buffers.buffers.size())) { + const auto& buffer = global_buffers.buffers[i]; + const auto& zero_init = global_buffers.zero_init[i]; std::cout << " " << buffer.scalar_type() << " " << buffer.sizes() - << std::endl; + << " is_zero_initialized: " << zero_init << std::endl; } } diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index 38f5ba31a63e1..3182a9273d8a8 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -101,6 +102,7 @@ std::string kernelPreamble() { ss << nvfuser_resources::tensorcore_cu; ss << nvfuser_resources::memory_cu; ss << nvfuser_resources::fused_reduction_cu; + ss << nvfuser_resources::swizzle_cu; // Random utilities ss << nvfuser_resources::PhiloxCudaStateRaw_cu; @@ -1007,6 +1009,12 @@ std::pair nvrtcCompile( if (ptxas_opt_level) { int val = atoi(ptxas_opt_level); if (val <= 4 && val >= 0) { + if (val < 4) { + TORCH_WARN( + "ptxas optimization 
level manually set as ", + val, + ", which could negatively affect performance. Try removing env variable PYTORCH_NVFUSER_JIT_OPT_LEVEL for optimal performance."); + } if (compile_to_sass) { jit_opt_level += std::to_string(val); args.push_back("--ptxas-options"); diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 837982facc59b..ec6243853caf1 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -30,69 +29,6 @@ namespace cuda { namespace { -// Update the HaloInfo mappings for a reference tensor by propagating -// the halo information from the consumer tensor. -void updateHaloInfoForReference( - const ReferenceTensor& reference, - const TensorView* consumer_tv) { - const auto gpu_lower = GpuLower::current(); - - auto& halo_info = gpu_lower->haloInfo(); - - auto reference_domain = reference.domain; - - // First, propagate the halo information of the consumer root domain - // to the reference root domain. - for (auto consumer_root_id : consumer_tv->getRootDomain()) { - auto consumer_index_concrete_id = gpu_lower->caMap()->getConcreteMappedID( - consumer_root_id, IdMappingMode::EXACT); - auto reference_it = - reference.concrete_to_id.find(consumer_index_concrete_id); - if (reference_it == reference.concrete_to_id.end()) { - // This happens when consumer_root_id is a broadcast or an - // initialization of a reduction buffer. In those cases, since - // the domain is not going to be predicated, it's not necessary - // to propagate halo information to the reference tensor. - continue; - } - auto reference_id = reference_it->second; - halo_info.setRootAxisInfo( - reference_id, halo_info.getRootAxisInfo(consumer_root_id)); - } - - // Now that the reference root has halo information copied from - // the cosumer, propagate it down to non-root domains. - halo_info.build(reference_domain); - - return; -} - -// Get a map of IterDomains to halo-extended extents of corresponding -// reference IterDomains. -// -// ref_map: ref-to-consumer in consumer indexing; ref-to-producer in -// producer indexing -std::unordered_map getReferenceHaloExtentMap( - const ReferenceTensor& reference, - const std::unordered_map& index_map_from_ref) { - const auto& halo_info = GpuLower::current()->haloInfo(); - - std::unordered_map reference_halo_extent_map; - - // Propagate halo extents of the reference to the consumer or - // producer tensor - for (auto kv : index_map_from_ref) { - auto ref_id = kv.first; - auto producer_or_consumer_id = kv.second; - auto extent = halo_info.getExtent(ref_id); - if (extent != nullptr) { - reference_halo_extent_map[producer_or_consumer_id] = extent; - } - } - - return reference_halo_extent_map; -} - //! Offset of an index of a producer axis with respect to its //! 
corresponding consumer index int getProducerHaloOffset( @@ -575,10 +511,36 @@ void IndexCompute::handle(Merge* merge) { } } +void IndexCompute::handle(Swizzle2D* swizzle_2d) { + auto out_x_id = maybeGetExactMapConcreteID(swizzle_2d->outX()); + auto out_y_id = maybeGetExactMapConcreteID(swizzle_2d->outY()); + auto in_x_id = maybeGetExactMapConcreteID(swizzle_2d->inX()); + auto in_y_id = maybeGetExactMapConcreteID(swizzle_2d->inY()); + + auto out_x_it = index_map_.find(out_x_id); + auto out_y_it = index_map_.find(out_y_id); + + if (out_x_it == index_map_.end() || out_y_it == index_map_.end()) { + return; + } + + const auto out_x_ind = out_x_it->second; + const auto out_y_ind = out_y_it->second; + + // Actual swizzle operation is handled via IndexSwizzle pass + // all behavior in this pass is directly forward through the + // index and extent. + index_map_[in_x_id] = out_x_ind; + index_map_[in_y_id] = out_y_ind; + extent_map_[in_y_id] = getExtent(out_y_id); + extent_map_[in_x_id] = getExtent(out_x_id); +} + void IndexCompute::handle(Expr* e) { switch (e->getExprType().value()) { case (ExprType::Split): case (ExprType::Merge): + case (ExprType::Swizzle2D): break; default: TORCH_INTERNAL_ASSERT( @@ -804,6 +766,16 @@ class UpdateLeafIndices : public IterVisitor { return; } + if (!index_map_.count(in_id)) { + // Reduction axes on producer side could be visited on forward + // propagation pass and current implementation does not yet + // support reduciton on swizzled iterdomains, so un-indexed + // reduction iterdomains are just ignored for now. + TORCH_INTERNAL_ASSERT( + in_id->isReduction(), "Undefined index for ", in_id->toString()); + return; + } + auto factor = split->factor(); index_map_[inner_id] = SimplifyingIrBuilder::modExpr(index_map_[in_id], factor); @@ -819,6 +791,20 @@ class UpdateLeafIndices : public IterVisitor { auto outer_id = merge->outer(); auto inner_id = merge->inner(); + if (!index_map_.count(outer_id) || !index_map_.count(inner_id)) { + // Reduction axes on producer side could be visited on forward + // propagation pass and current implementation does not yet + // support reduciton on swizzled iterdomains, so un-indexed + // reduction iterdomains are just ignored for now. + TORCH_INTERNAL_ASSERT( + outer_id->isReduction() && inner_id->isReduction(), + "Undefined index for ", + outer_id->toString(), + " and ", + inner_id->toString()); + return; + } + // Nothing need to be done when mappings for the output axes // already exist. if (index_map_.find(out_id) != index_map_.end()) { @@ -830,7 +816,7 @@ class UpdateLeafIndices : public IterVisitor { TORCH_INTERNAL_ASSERT( index_map_.find(inner_id) != index_map_.end(), "Inner ID not found"); - index_map_[out_id] = SimplifyingIrBuilder::mulExpr( + index_map_[out_id] = SimplifyingIrBuilder::addExpr( index_map_[inner_id], SimplifyingIrBuilder::mulExpr( index_map_[outer_id], getExtent(inner_id))); @@ -839,6 +825,22 @@ class UpdateLeafIndices : public IterVisitor { SimplifyingIrBuilder::mulExpr(getExtent(outer_id), getExtent(inner_id)); } + void handle(Swizzle2D* swizzle_2d) override { + auto in_x = swizzle_2d->inX(); + auto in_y = swizzle_2d->inY(); + auto out_x = swizzle_2d->outX(); + auto out_y = swizzle_2d->outY(); + + // Forward propagation pass still just forward + // through the indices and the actual swizzle + // will be applied on the backward pass in + // IndexSwizzle class implementation. 
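+    // In other words, the outputs simply alias the input indices and extents
+    // here; the swizzled expression (swizzle2DIntExpr / pairSelectExpr) is
+    // only materialized later in IndexSwizzle::handle(Swizzle2D*).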
+ index_map_[out_x] = index_map_.at(in_x); + extent_map_[out_x] = getExtent(in_x); + index_map_[out_y] = index_map_.at(in_y); + extent_map_[out_y] = getExtent(in_y); + } + // return extent_map_[id] if exists, else return id->extent() Val* getExtent(IterDomain* id) { if (extent_map_.find(id) != extent_map_.end()) { @@ -889,6 +891,23 @@ IndexSwizzle::IndexSwizzle( swizzle_type_(tv->swizzleType()), ids_to_swizzle_(tv->axesToSwizzle()) {} +IndexSwizzle::IndexSwizzle( + const TensorView* tv, + const TensorDomain* domain, + std::unordered_map initial_index_map, + std::unordered_map extent_map, + std::unordered_set zero_domains, + std::unordered_set zero_merged_in) + : IndexCompute( + domain, + std::move(initial_index_map), + std::move(extent_map), + std::move(zero_domains), + std::move(zero_merged_in)), + tv_(tv), + swizzle_type_(tv->swizzleType()), + ids_to_swizzle_(tv->axesToSwizzle()) {} + void IndexSwizzle::run() { TORCH_INTERNAL_ASSERT( swizzle_type_ == SwizzleType::NoSwizzle || @@ -922,15 +941,35 @@ void IndexSwizzle::run() { swizzled_ids_.insert(id_to_swizzle_j); IndexCompute::run(); } + } else if (tv_->hasSwizzleOp()) { + // Propagate backward for the annotated swizzle path. + // TODO: + // eventually will unify the two swizzling implementation + // code path in a follow up. Currently just focusing on + // getting the necessary implementation of the swizzle + // operator ready. + // + // At this intermediate state, the legacy swizzle implementation + // takes precedence, i.e. whenever swizzle_type_ is not NoSwizzle, + // the new swizzle op pass is disabled. + UpdateLeafIndices update_leaves(td_, indexMap(), extentMap()); + index_map_ = update_leaves.indexMap(); + extent_map_ = update_leaves.extentMap(); + IndexCompute::run(); } } void IndexSwizzle::handle(Expr* e) { auto out_ids = ir_utils::filterByType(e->outputs()); bool needs_update = - std::any_of(out_ids.begin(), out_ids.end(), [this](IterDomain* id) { - return swizzled_ids_.find(id) != swizzled_ids_.end(); - }); + std::any_of( + out_ids.begin(), + out_ids.end(), + [this](IterDomain* id) { + return swizzled_ids_.find(id) != swizzled_ids_.end(); + }) || + (e->isA() && + e->as()->swizzleType() != Swizzle2DType::NoSwizzle); if (!needs_update) { return; } @@ -941,6 +980,48 @@ void IndexSwizzle::handle(Expr* e) { } } +void IndexSwizzle::handle(Swizzle2D* swizzle_2d) { + auto out_x_id = swizzle_2d->outX(); + auto out_y_id = swizzle_2d->outY(); + auto in_x_id = swizzle_2d->inX(); + auto in_y_id = swizzle_2d->inY(); + + auto out_x_it = index_map_.find(out_x_id); + auto out_y_it = index_map_.find(out_y_id); + + // TODO: unify the legacy path in all usage + TORCH_INTERNAL_ASSERT( + swizzle_type_ == SwizzleType::NoSwizzle, + "Cannot mix usage of two swizzle implementations"); + + TORCH_INTERNAL_ASSERT( + out_x_it != index_map_.end() && out_y_it != index_map_.end(), + "Swizzle output indices were not propagated through"); + + const auto out_x_ind = out_x_it->second; + const auto out_y_ind = out_y_it->second; + + // Can propagate zero only for a few + // swizzle types (TODO) + + if (swizzle_2d->swizzleType() != Swizzle2DType::NoSwizzle) { + auto out_pair = IrBuilder::swizzle2DIntExpr( + out_x_ind, + out_y_ind, + getExtent(out_x_id), + getExtent(out_y_id), + swizzle_2d->swizzleType()); + + index_map_[in_x_id] = + IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X); + index_map_[in_y_id] = + IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y); + + swizzled_ids_.insert(in_x_id); + swizzled_ids_.insert(in_y_id); + } +} 
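+// Rough sketch of the expression structure built above, for a pair of
+// indices (ix, iy) with extents (ex, ey); the swizzle type always comes from
+// the Swizzle2D op itself:
+//
+//   auto pair  = IrBuilder::swizzle2DIntExpr(ix, iy, ex, ey, swizzle_type);
+//   auto new_x = IrBuilder::pairSelectExpr(pair, kir::PairSelect::Selection::X);
+//   auto new_y = IrBuilder::pairSelectExpr(pair, kir::PairSelect::Selection::Y);
+//
+// i.e. one kir::Swizzle2DInt producing a kir::IntPair that then feeds two
+// kir::PairSelect nodes, one per output coordinate.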
+ // Used for local and shared index mapping. Returns a map from loops // to loop indices as well as a set of loops that do not contribute to // indexing. @@ -1129,45 +1210,6 @@ void ensureStaticIndexing( namespace { -// Map everything we can from reference to provided tv using the provided -// compute at map. If root_only is true, only root domains are included. -// We can't simply try to use the provided tv root domains and -// map those to the reference as the provided tv may have root domains that -// don't exist in reference. This can happen when the provided tv is from before -// a view, but all the loops are generated from TVs generated after the view -// operation. -std::unordered_map indexMapReferenceTo( - const TensorView* tv, - const std::unique_ptr& ca_map, - const std::unordered_map& - reference_concrete_to_id_map, - bool root_only = false) { - std::unordered_map index_map_ref_to_producer; - - auto gen_map = [&](const auto& pids) { - for (auto p_id : pids) { - auto concrete_id = - ca_map->getConcreteMappedID(p_id, IdMappingMode::EXACT); - auto ref_id_it = reference_concrete_to_id_map.find(concrete_id); - if (ref_id_it != reference_concrete_to_id_map.end()) { - index_map_ref_to_producer[ref_id_it->second] = p_id; - } - } - }; - - if (root_only) { - gen_map(tv->getRootDomain()); - } else { - auto all_pid_vals = DependencyCheck::getAllValsBetween( - {tv->getRootDomain().begin(), tv->getRootDomain().end()}, - {tv->domain()->domain().begin(), tv->domain()->domain().end()}); - auto all_pids = ir_utils::filterByType(all_pid_vals); - gen_map(all_pids); - } - - return index_map_ref_to_producer; -} - //! Returns an iterdomain that corresponds to the //! indexing sub-expression to hoist or a nullopt //! if the index should not be hoisted. @@ -1186,6 +1228,11 @@ c10::optional getMaybeIndexedIdToHoist( return c10::nullopt; } + // New swizzle interface not yet supported + if (tv->hasSwizzleOp()) { + return c10::nullopt; + } + // Find the true indexed domain, which can be a merged contiguous domain. auto contig_id_it = indexing.rootToContigID().find(root_id); TORCH_INTERNAL_ASSERT( @@ -1613,7 +1660,32 @@ std::vector Index::getNonGlobalProducerStridedIndices( index_swizzle.run(); - const auto& index_map = index_swizzle.indexMap(); + auto producer_swizzled_index = index_swizzle; + + if (producer_tv->hasSwizzleOp()) { + // Special handling needed on the new swizzle + // op pass: + // each swizzle op is local to the tensor, + // so ReplayPasC will not include the swizzle + // ops on the producer iterdomain. So would + // need to traverse forward the producer domain + // before the replay to get the swizzle ops. + IndexSwizzle producer_swizzle2d( + producer_tv, + domain_guard.prevDomain(), + producer_indexing.indexMap(), + producer_indexing.extentMap(), + producer_indexing.zeroDomains(), + producer_indexing.zeroMergedIn()); + producer_swizzle2d.run(); + producer_swizzled_index = producer_swizzle2d; + } + + // TODO: merge the two swizzle compute logic once the new one is ready. + // will need to replace cyclic shift swizzle with xor since swizzle2d + // doesn't have cyclic shift. + const auto& index_map = producer_swizzled_index.indexMap(); + const auto& extent_map = producer_indexing.extentMap(); const auto& zero_domain_map = producer_indexing.zeroDomains(); // Indices should now be mapped onto IterDomains in producer, so just grab @@ -2250,10 +2322,11 @@ bool needsPadding(TensorView* tv) { } // Get an additional offset of a stop index when building a predicate -// for unswitch. 
Initial stop indices generated at getPredicateReferenceIndexing -// do not take halo into account, and the adjustment for halo is done as an -// additional offset to the final index value so that unswitch predicates can be -// compared with each other by just looking at the additional offsets. +// for unswitch. Initial stop indices generated at +// getPredicateIndexingFromIdGraph do not take halo into account, and the +// adjustment for halo is done as an additional offset to the final index value +// so that unswitch predicates can be compared with each other by just looking +// at the additional offsets. // // consumer_root_id: the domain for which a stop predicate is being built. int getUnswitchStopOffset( @@ -2477,206 +2550,11 @@ std::pair getStartAndStopLimitOffsets( return {start_limit, stop_limit}; } -// Return an IndexCompute for a predicate reference tensor. Two different -// maps are used when generating predicates for unswitched expressions -// as start and stop conditions need to use different loop-to-index -// mappings. -auto getPredicateReferenceIndexing( - const std::vector& loops, - const ReferenceTensor& reference, - kir::ForLoop* unswitch_or_vec_loop, - IterDomain* double_buffer_axis, - bool start) { - auto reference_domain = reference.domain; - - std::unordered_map loop_to_ind_map; - - std::transform( - loops.begin(), - loops.end(), - std::inserter(loop_to_ind_map, loop_to_ind_map.begin()), - [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); }); - - // If unswitch don't directly use indices from for loop, use zero - // and for loop extent minus 1 - if (unswitch_or_vec_loop != nullptr) { - // Vectorized predicates are different from unswitch. Unswitch predicates - // all loops within the unswitch (the outer most unswitch) are generated - // with loop->extent-1 as the index. With vectorized predicates, only the - // vectorized loop should be like this. - - bool vectorized_pred = - unswitch_or_vec_loop->iter_domain()->getParallelType() == - ParallelType::Vectorize; - - TORCH_INTERNAL_ASSERT( - loops.size() <= reference_domain->nDims(), - "Invalid reference generated."); - - bool within_unswitch = false; - - for (const auto loop_i : c10::irange(loops.size())) { - auto loop = loops[loop_i]; - auto loop_id = loop->iter_domain(); - auto loop_pt = loop_id->getParallelType(); - auto ref_id = reference_domain->axis(loop_i); - - if (loop == unswitch_or_vec_loop) { - within_unswitch = true; - } - - if (within_unswitch) { - // Rely on the reference to check broadcasting. The for loop could be - // broadcasted on a constant value from an unroll split. Since reference - // may convert this to an iter domain, that for loop could be valid to - // generate predication from. - - // Note that loop->stop() is not used below. Instead, - // loop->iter_domain()->extent() is used, which is uniform - // across the mapped domains irrespective of halo. Predicates are - // compared with each to pick the most restrictive ones. The - // comparison is done by only using the offset, which is the - // term added to the index. So, the index term must be the - // same among all predicates, otherwise the comparison would - // be invalid. The effect by halo is added to the offset - // term. See getUnswitchStopOffset. - - if (ref_id->isBroadcast()) { - // Ignore indexing into broadcasted dimensions. 
- continue; - } else if (loop_id->isThread()) { - // When parallelized, if the loop stop is the same as the - // extent of the associated IterDomain, i.e., no extra - // iterations for halo, predicating with the threading index - // is sufficient for both the start and stop - // predicates. That isn't the case if the loop has halo, and - // in the case either the minimum and maximum values of the - // iteration domain needs to be used. - // - // Note: Better performance was obtained if using - // threadIdx in unswitch predicates was avoided. More - // specifically, in the Hdiff stencil example, instead of - // predicating with threadIdx.x for both the start and stop - // predicates, using zero and (blockDim.x - 1) for the start - // and stop predicates, respectively, resulted in less - // register pressure. The alternative codegen can be done by - // adding this to the first if condition: - // loop_id->isBlockDim(). This would not be a concern if the - // else part could be omitted, so canOmitElseClause should - // be used as well. - if (loop->stop() == loop_id->extent()) { - loop_to_ind_map[loop] = loop->start(); - } else if (start) { - loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal(); - } else { - // Note that the parallel dimension is used rather than - // loop-stop(). See the above comment. - loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr( - GpuLower::current()->parallelDimensionMap().get(loop_pt), - GpuLower::current()->kernel()->zeroVal()); - } - } else if (start) { - loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal(); - } else { - // Similar to the above, loop_id()->extent() is - // used here instead of loop->stop(). See the above comment. - loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr( - loop_id->extent(), GpuLower::current()->kernel()->oneVal()); - } - } - - // If a vectorized predicate, bail after the vectorized loop was found. - // Don't continue unswitching loops. - if (vectorized_pred && within_unswitch) { - break; - } - } - } - - for (const auto loop : loops) { - auto& idx = loop_to_ind_map.at(loop); - // If the loop is trivial, the loop index can only be the loop - // start value. - if (idx == loop->index() && loop->isTrivial()) { - idx = loop->start(); - } - } - - if (double_buffer_axis != nullptr) { - auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( - double_buffer_axis, loops, true); - if (db_loop != nullptr) { - auto loop_to_ind_map_it = loop_to_ind_map.find(db_loop); - TORCH_INTERNAL_ASSERT(loop_to_ind_map_it != loop_to_ind_map.end()); - auto cur_index = loop_to_ind_map_it->second; - // if cur_index is not the same as the index of db_loop, it must - // be true that that index has been modified to support - // unswitch. In that case, it is not necessary to move ahead the - // index for double buffering. 
- if (cur_index == db_loop->index()) { - loop_to_ind_map[db_loop] = SimplifyingIrBuilder::addExpr( - cur_index, GpuLower::current()->kernel()->oneVal()); - } - } - } - - // Add magic zero to a loop pretty far inside in indexing - IterDomain* magic_zero_loop = nullptr; - std::unordered_map ref_id_to_ind_map; - // Due to rfactor/initialization reference_domain may be bigger than loop nest - // structure - TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims()); - for (const auto loop_i : c10::irange(loops.size())) { - auto loop = loops[loop_i]; - auto ind = loop_to_ind_map[loops[loop_i]]; - auto ref_axis = reference_domain->axis(loop_i); - - if (Index::protectWithMagicZero(loop, ref_axis, ind)) { - magic_zero_loop = ref_axis; - } - - ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loop]; - } - - if (ref_id_to_ind_map.count(magic_zero_loop)) { - auto& ind = ref_id_to_ind_map[magic_zero_loop]; - if (!ind->isConstScalar()) { - ind = SimplifyingIrBuilder::addExpr( - ind, GpuLower::current()->kernel()->magicZeroVal()); - } - } - - std::unordered_map ref_self_map; - auto all_vals = DependencyCheck::getAllValsBetween( - {reference_domain->getRootDomain().begin(), - reference_domain->getRootDomain().end()}, - {reference_domain->domain().begin(), reference_domain->domain().end()}); - auto all_ids = ir_utils::filterByType(all_vals); - std::for_each(all_ids.begin(), all_ids.end(), [&ref_self_map](auto id) { - ref_self_map.insert({id, id}); - }); - - std::unordered_map reference_halo_extent_map = - getReferenceHaloExtentMap(reference, ref_self_map); - - // Index into the reference tensor - auto index_compute = getReferenceIndexing( - loops, - reference_domain, - ref_id_to_ind_map, - {}, - {}, - reference_halo_extent_map); - - return index_compute; -} - // Get the offsets for the start and stop predicates. The offsets // are to be added to the index. std::pair getStartAndStopOffsets( IterDomain* consumer_id, TensorView* consumer_tv, - const ReferenceTensor& reference, const std::unordered_map& consumer_start_index_map, const std::unordered_map& consumer_stop_index_map, bool padding_predicate, @@ -2720,7 +2598,7 @@ std::pair getStartAndStopOffsets( // If generating a predicate for unswitch, adjust the stop offset to // accommodate the addition of halo to the loop stop. See the - // comment in getPredicateReferenceIndexing as well. + // comment in getPredicateIndexingFromIdGraph as well. 
if (unswitch) { TORCH_INTERNAL_ASSERT( !padding_predicate, "Unswitch should not use the padding predicate"); @@ -2842,12 +2720,12 @@ std::pair hoistPredicates( Val* start_index, Val* stop_index, const std::vector& loops, + std::vector loop_domains, + const std::unordered_map start_initial_loop_index_map, + const std::unordered_map stop_initial_loop_index_map, kir::ForLoop* unswitch_or_vec_loop, IterDomain* predicated_consumer_id, - TensorView* predicated_consumer_tv, - TensorDomain* ref_td, - const std::unordered_map& ref_start_index_map, - const std::unordered_map& ref_stop_index_map) { + TensorView* predicated_consumer_tv) { const std::pair same_indices{start_index, stop_index}; if (isDisabled(DisableOption::IndexHoist)) { @@ -2867,8 +2745,8 @@ std::pair hoistPredicates( GpuLower::current()->commonIndexMap().insert( predicated_consumer_id, predicated_consumer_tv->domain(), - ref_td, - ref_stop_index_map, + loop_domains, + stop_initial_loop_index_map, loops, stop_index); } @@ -2884,8 +2762,8 @@ std::pair hoistPredicates( GpuLower::current()->commonIndexMap().insert( predicated_consumer_id, predicated_consumer_tv->domain(), - ref_td, - ref_start_index_map, + loop_domains, + start_initial_loop_index_map, loops, start_index); } @@ -2896,12 +2774,11 @@ std::pair hoistPredicates( } // namespace // Returns predicates and the concrete (by loop map) root domains they cover -std::pair, ReferenceTensor> Index:: - getReferenceRootPredicates( - TensorView* consumer_tv, - const std::vector& loops, - kir::ForLoop* unswitch_or_vec_loop, - bool shift_padding) { +std::vector Index::getReferenceRootPredicates( + TensorView* consumer_tv, + const std::vector& loops, + kir::ForLoop* unswitch_or_vec_loop, + bool shift_padding) { FUSER_PERF_SCOPE("GpuLower::Lower::Index::getReferenceRootPredicates"); const auto gpu_lower = GpuLower::current(); @@ -2910,22 +2787,9 @@ std::pair, ReferenceTensor> Index:: // Nothing needs to be done when padding is not required. if (shift_padding && !needsPadding(consumer_tv)) { - return {{RootPredicateInfo::getFalseInfo()}, ReferenceTensor{}}; + return {RootPredicateInfo::getFalseInfo()}; } - // Get a reference tensor replayed as existing loop structure - ReferenceTensor reference = - IndexReferenceReplay::getReference(loops, consumer_tv); - - // Generate halo information for reference. - updateHaloInfoForReference(reference, consumer_tv); - - const auto ref_2_consumer = indexMapReferenceTo( - consumer_tv, gpu_lower->caMap(), reference.concrete_to_id); - - const auto reference_halo_extent_map = - getReferenceHaloExtentMap(reference, ref_2_consumer); - auto db_axis = gpu_lower->doubleBufferInfo().getDoubleBufferAxis(consumer_tv); // Indexing is done without considering contig merging. Actual @@ -2936,38 +2800,27 @@ std::pair, ReferenceTensor> Index:: std::vector(consumer_tv->getMaybeRFactorDomain().size(), false), {}); + // Generate start and stop indexing from idgraph. + // // Both start and stop positions may need to be predicated. Indexing // differs when generating predicates for unswitch. // NOTE: If we could find-and-replace KIR nodes, we could just // generate one index map, clone it and replace the loop-to-index // mappings of unswitched loops for the start predicate. 
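+  // (Roughly: for loops inside the unswitch, the start predicate indexes
+  // them with their start/zero value while the stop predicate uses
+  // extent - 1, so the most restrictive predicate can be picked by comparing
+  // only the added offsets.)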
- auto ref_stop_indexing = getPredicateReferenceIndexing( - loops, reference, unswitch_or_vec_loop, db_axis, false); - const auto consumer_stop_indexing = ref_stop_indexing.updateIndexCompute( - consumer_tv->domain(), - ref_2_consumer, - contig_finder, - reference_halo_extent_map); + + auto stop_indexing_from_idgraph = getPredicateIndexingFromIdGraph( + loops, consumer_tv, unswitch_or_vec_loop, db_axis, false); + const auto consumer_stop_indexing = stop_indexing_from_idgraph.index; const auto& consumer_stop_index_map = consumer_stop_indexing.indexMap(); // If not unswitch, share the same indexing map as the stop index // map - const auto& ref_start_indexing = is_unswitch - ? getPredicateReferenceIndexing( - loops, reference, unswitch_or_vec_loop, db_axis, true) - : ref_stop_indexing; - - std::unordered_map consumer_start_index_map; - if (is_unswitch) { - const auto consumer_start_indexing = ref_start_indexing.updateIndexCompute( - consumer_tv->domain(), - ref_2_consumer, - contig_finder, - reference_halo_extent_map); - consumer_start_index_map = consumer_start_indexing.indexMap(); - } else { - consumer_start_index_map = consumer_stop_index_map; - } + const auto start_indexing_from_idgraph = is_unswitch + ? getPredicateIndexingFromIdGraph( + loops, consumer_tv, unswitch_or_vec_loop, db_axis, true) + : stop_indexing_from_idgraph; + const auto consumer_start_indexing = start_indexing_from_idgraph.index; + const auto& consumer_start_index_map = consumer_start_indexing.indexMap(); // Get the contiguous ids we need to generate predicates for auto contig_id_infos = @@ -3024,7 +2877,6 @@ std::pair, ReferenceTensor> Index:: std::tie(info.start_offset_, info.stop_offset_) = getStartAndStopOffsets( contig_id, consumer_tv, - reference, consumer_start_index_map, consumer_stop_index_map, shift_padding, @@ -3038,12 +2890,12 @@ std::pair, ReferenceTensor> Index:: start_index, stop_index, loops, + stop_indexing_from_idgraph.resolved_loop_domains, + start_indexing_from_idgraph.initial_concrete_index_map, + stop_indexing_from_idgraph.initial_concrete_index_map, unswitch_or_vec_loop, contig_id, - consumer_tv, - reference.domain, - ref_start_indexing.indexMap(), - ref_stop_indexing.indexMap()); + consumer_tv); // Build predicates for start positions as: // start_index + start_offset >= 0 @@ -3080,7 +2932,7 @@ std::pair, ReferenceTensor> Index:: pred_info_vec.emplace_back(info); } - return {pred_info_vec, reference}; + return pred_info_vec; } bool Index::protectWithMagicZero( diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h index e9386f5a53a2c..e90161aab0b7e 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.h +++ b/torch/csrc/jit/codegen/cuda/index_compute.h @@ -70,6 +70,7 @@ class IndexCompute : public BackwardVisitor { void handle(Split*) override; void handle(Merge*) override; void handle(Expr*) override; + void handle(Swizzle2D*) override; // return extent_map_[id] if exists, else return id->extent() Val* getExtent(IterDomain* id) const; @@ -204,6 +205,14 @@ class IndexSwizzle : public IndexCompute { std::unordered_set zero_domains, std::unordered_set zero_merged_in); + IndexSwizzle( + const TensorView* tv, + const TensorDomain* domain, + std::unordered_map initial_index_map, + std::unordered_map extent_map, + std::unordered_set zero_domains, + std::unordered_set zero_merged_in); + void run() override; protected: @@ -211,6 +220,8 @@ class IndexSwizzle : public IndexCompute { void handle(Expr* e) override; + void handle(Swizzle2D* swizzle_2d) 
override; + private: const TensorView* tv_ = nullptr; SwizzleType swizzle_type_ = SwizzleType::NoSwizzle; @@ -345,8 +356,7 @@ class Index { //! this is not a bool value as if we have an unswitch loop with a vectorized //! loop inside, we only want to base the "unswitch" like predicate on the //! vectorized loop. - static std::pair, ReferenceTensor> - getReferenceRootPredicates( + static std::vector getReferenceRootPredicates( TensorView* consumer_tv, const std::vector& loops, kir::ForLoop* unswitch_or_vec_loop, diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp new file mode 100644 index 0000000000000..d35c72e3a61d3 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp @@ -0,0 +1,368 @@ +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +bool InlinePropagatorSelector::allowC2P(TensorView* from, TensorView* to) { + return selected_.count(to) > 0; +} + +bool InlinePropagatorSelector::allowP2C(TensorView* from, TensorView* to) { + // If the producer is in the selected set, then the consumer must also be + // replayed to obtain a compatible loop structure so that this producer + // can be consumed in this loop. + return selected_.count(from) > 0 || selected_.count(to) > 0; +} + +bool InlinePropagatorSelector::allowSibling(TensorView* from, TensorView* to) { + return true; +} + +MaxPosCalculator::MaxPosCalculator(ComputeAtMode mode) : mode_(mode) { + buildUnmappableDims(); +} + +void MaxPosCalculator::buildUnmappableDims() { + ComputeAtRootDomainMap root_map; + root_map.build(); + + auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion()); + for (auto tv : all_tvs) { + auto consumers = ir_utils::consumerTvsOf(tv); + for (auto consumer : consumers) { + // Grab dimensions in producer and consumer that are mappable to eachother + // based on the computeAtRootDomainMap. This will tell us which dimensions + // can be inlined based on avoiding trying to inline non-trivial + // reduction structures. + auto mappable_roots = + root_map.getMappableDims(tv->domain(), consumer->domain()); + for (auto tv_root_id : tv->getMaybeRFactorDomain()) { + if (mappable_roots.find(tv_root_id) == mappable_roots.end() && + !tv_root_id->isTrivialReduction()) { + unmappable_dims_.emplace(tv_root_id); + } + } + } + } +} + +bool MaxPosCalculator::isAllowedID( + IterDomain* id, + TensorView* tv, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const { + bool allowed = true; + + if (!allow_reduction) { + allowed = allowed && !id->isReduction(); + } + + if (!allow_vectorize) { + // Avoid inlining if marked as Vectorize or Group. In the case of + // BestEffort and MostInlined modes, avoid Unroll as well. 
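+    // For example, assuming a tv whose leaf domain is [I0, I1] with I1
+    // vectorized and I0 otherwise inlinable, getMaxPosSelf(tv, ...) with
+    // allow_vectorize = false returns 1, so inlining stops before the
+    // vectorized axis.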
+ bool is_vectorize = isParallelTypeVectorize(id->getParallelType()) || + id->getParallelType() == ParallelType::Group || + ((mode_ == ComputeAtMode::BestEffort || + mode_ == ComputeAtMode::MostInlined) && + id->getParallelType() == ParallelType::Unroll); + allowed = allowed && !is_vectorize; + } + + if (!allow_unmappable) { + auto root_dom = tv->getMaybeRFactorDomain(); + std::unordered_set root_dom_set(root_dom.begin(), root_dom.end()); + auto all_vals = DependencyCheck::getAllValsBetween(root_dom_set, {id}); + bool is_unmappable = false; + for (auto val : all_vals) { + auto id = val->as(); + if (root_dom_set.count(val) > 0 && unmappable_dims_.count(id) > 0) { + is_unmappable = true; + break; + } + } + allowed = allowed && !is_unmappable; + } + + return allowed; +} + +size_t MaxPosCalculator::getMaxPosSelf( + TensorView* tv, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const { + auto dom = tv->domain()->domain(); + auto iter = std::find_if(dom.begin(), dom.end(), [=](IterDomain* id) { + return !isAllowedID( + id, tv, allow_reduction, allow_vectorize, allow_unmappable); + }); + return std::distance(dom.begin(), iter); +} + +// Return the max position in producer that can be inlined to consumer +// Cannot inline: +// Vectorized dimensions in consumer +// Unrolled dimensions in consumer +size_t MaxPosCalculator::getMaxProducerPosFromConsumer( + TensorView* producer, + TensorView* consumer) const { + auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); + auto replay_CasP = + BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); + auto p2c_replay_map = replay_CasP.getReplay(); + + for (size_t producer_pos = 0; producer_pos < producer->nDims(); + producer_pos++) { + auto map_it = p2c_replay_map.find(producer->axis(producer_pos)); + if (map_it != p2c_replay_map.end()) { + auto c_id = map_it->second; + if (!isAllowedID(c_id, consumer, true, false, true)) { + return producer_pos; + } + } + } + return producer->nDims(); +} + +size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) { + auto max_pos = max_pos_calc.getMaxPosSelf(tv, false, false, false); + for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { + max_pos = std::min( + max_pos, max_pos_calc.getMaxProducerPosFromConsumer(tv, consumer_tv)); + } + if (check_siblings) { + for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) { + max_pos = std::min(max_pos, getMaxPosAll(sibling_tv, false)); + } + } + return max_pos; +} + +void InlinePropagator::setCAPos(TensorView* tv) { + size_t pos = mapped_reference_pos_.at(tv); + if ((selected_.empty() || selected_.count(tv)) && !tv->isFusionInput()) { + auto max_pos = getMaxPosAll(tv); + if (mode_ == ComputeAtMode::Standard) { + TORCH_INTERNAL_ASSERT( + pos <= max_pos, + "Invalid compute at position detected in InlinePropagator when trying to set the CA position of: ", + tv, + " to ", + pos, + ", max position that's allowed is ", + max_pos); + } else { + pos = std::min(pos, max_pos); + } + // hoist inner most broadcast + while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) { + pos--; + } + tv->setComputeAt(pos); + } +} + +InlinePropagator::InlinePropagator( + TensorView* reference, + int64_t reference_pos, + ComputeAtMode mode, + std::unordered_set selected) + : max_pos_calc(mode), + selected_(std::move(selected)), + reference_(reference), + mode_(mode) { + if (reference_pos < 0) { + reference_pos += int64_t(reference->nDims()) + 1; + } + TORCH_INTERNAL_ASSERT( + reference_pos >= 0 && reference_pos <= reference->nDims(), + 
"Invalid computeAt axis, received ", + reference_pos, + " but should be > -", + reference->nDims(), + " and <= ", + reference->nDims(), + "."); + reference_pos_ = reference_pos; +} + +void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) { + if (is_first_) { + is_first_ = false; + mapped_reference_pos_[reference_] = reference_pos_; + setCAPos(reference_); + } + // Step 1: find mapped_reference_pos_[to] + int from_pos; + if (mode_ != ComputeAtMode::MostInlined) { + from_pos = mapped_reference_pos_.at(from); + } else { + from_pos = from->nDims(); + } + auto to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos); + TORCH_CHECK( + to_pos >= 0, + "Unable to propagate CA position from consumer ", + from, + " at ", + from_pos, + " to producer ", + to, + " because this would require replay."); + mapped_reference_pos_[to] = to_pos; + // Step 2: set CA position of `to` + setCAPos(to); +} + +void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) { + if (is_first_) { + is_first_ = false; + mapped_reference_pos_[reference_] = reference_pos_; + setCAPos(reference_); + } + // Step 1: find mapped_reference_pos_[to] + int from_pos; + if (mode_ != ComputeAtMode::MostInlined) { + from_pos = mapped_reference_pos_.at(from); + } else { + from_pos = from->nDims(); + } + auto to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos); + TORCH_CHECK( + to_pos >= 0, + "Unable to propagate CA position from producer ", + from, + " at ", + from_pos, + " to consumer ", + to, + " because this would require replay."); + mapped_reference_pos_[to] = to_pos; + // Step 2: set CA position of `to` + setCAPos(to); +} + +void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) { + if (is_first_) { + is_first_ = false; + mapped_reference_pos_[reference_] = reference_pos_; + setCAPos(reference_); + } + // Step 1: find mapped_reference_pos_[to] + auto from_pos = mapped_reference_pos_.at(from); + TORCH_CHECK( + TransformReplay::fullSelfMatching(to, from), + "Unable to propagate CA position from ", + from, + " to sibling ", + to, + " because this would require replay."); + mapped_reference_pos_[to] = from_pos; + // Step 2: set CA position of `to` + setCAPos(to); +} + +namespace { + +// Try to find the aligned position on consumer's domain corresponding to the +// compute at position of producer domain. Used in computeAt pass only. No +// checking on actual producer-consumer relationship. +unsigned int getConsumerPosAlignedToProducerCA( + TensorView* consumer, + TensorView* producer) { + // Locate consumer's position that aligns with + // the producer's new compute at axis. We need broadcast axes forwarded so we + // need to replay PasC as CasP will not forward braodcast dims. For example + // if we have: + // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) + // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will + // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to + // NVFuserTest.FusionComplexBCast1_CUDA + + auto c2p_map = + BestEffortReplay::replayPasC( + producer, + consumer, + -1, + // Compute at root domain may not be valid here, as all + // producers don't have to be able to map into consumer at + // max producer position. Since computeAt should be valid + // and this mechanism is only intended to lower produce + // position of consumer, we can simply use the pairwise map. 
+ PairwiseRootDomainMap(producer, consumer)) + .getReplay(); + + // Find the innermost position of consumer that has + // been mapped within the producer ca axis. + unsigned int consumer_pos = consumer->nDims(); + while (consumer_pos > 0) { + auto consumer_id = consumer->axis((int)consumer_pos - 1); + auto p_dom = producer->domain()->domain(); + if (std::any_of( + p_dom.begin(), + p_dom.begin() + producer->getComputeAtPosition(), + [&consumer_id, &c2p_map](IterDomain* p_id) { + auto c_id_it = c2p_map.find(consumer_id); + if (c_id_it != c2p_map.end()) { + return c_id_it->second == p_id; + } + return false; + })) { + break; + } + consumer_pos--; + } + + return consumer_pos; +} + +} // namespace + +// Try to find the aligned position on consumer's domain corresponding to the +// compute at position of producer domain. +void MaxProducerPosUpdater::handle(TensorView* consumer) { + unsigned int consumer_pos = 0; + for (auto producer : ir_utils::producerTvsOf(consumer)) { + consumer_pos = std::max( + consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer)); + } + consumer->setMaxProducer(consumer_pos); +} + +void MaxProducerPosUpdater::propagateC2P(TensorView* from, TensorView* to) { + if (updated_.empty()) { + // handle the reference tensor + updated_.insert(nullptr); + propagateC2P(nullptr, from); + } + for (auto consumer_tv : ir_utils::consumerTvsOf(to)) { + if (updated_.count(consumer_tv) > 0) { + continue; + } + handle(consumer_tv); + updated_.insert(consumer_tv); + } +} + +void MaxProducerPosUpdater::propagateP2C(TensorView* from, TensorView* to) { + propagateC2P(from, to); +} + +void MaxProducerPosUpdater::propagateSibling(TensorView* from, TensorView* to) { + propagateC2P(from, to); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h new file mode 100644 index 0000000000000..d7cd1f82a8d90 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/inline_propagator.h @@ -0,0 +1,143 @@ +#pragma once + +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// Simple selector that only propagates across tensor views in the provided +// unordered_set. Will also propagate to all consumers of those tensors, and the +// siblings of those tensors. +class TORCH_CUDA_CU_API InlinePropagatorSelector + : public MaxInfoSpanningTree::Selector { + std::unordered_set selected_; + + public: + virtual bool allowC2P(TensorView* from, TensorView* to) override; + virtual bool allowP2C(TensorView* from, TensorView* to) override; + virtual bool allowSibling(TensorView* from, TensorView* to) override; + + InlinePropagatorSelector(std::unordered_set selected) + : selected_(std::move(selected)){}; + const std::unordered_set& selected() const { + return selected_; + } +}; + +class TORCH_CUDA_CU_API MaxPosCalculator { + ComputeAtMode mode_ = ComputeAtMode::Standard; + + // Root domains in producer that's unmappable to any of its consumers + std::unordered_set unmappable_dims_; + + // Iterate through all TVs and collect the dimensions of each TV that don't + // map to all its consumer TVs. + void buildUnmappableDims(); + + // Utility function to return if an id of tv is a valid iter domain to inline + // within. This is used in getMaxPos{PasC,CasP}. Different variations of the + // bool values are used if checking max position of PasC, CasP, or checking + // for a max "self" position. 
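+  // For instance, getMaxProducerPosFromConsumer calls this with
+  // (allow_reduction = true, allow_vectorize = false, allow_unmappable = true),
+  // so a consumer id only blocks inlining there when it is vectorized,
+  // grouped, or (for BestEffort / MostInlined) unrolled.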
+ bool isAllowedID( + IterDomain* id, + TensorView* tv, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const; + + public: + // Returns the position at which tv can be inlined within. + size_t getMaxPosSelf( + TensorView* tv, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const; + + // Returns the maximum position producer can be inlined based on consumer + // given the set ComputeAtMode + size_t getMaxProducerPosFromConsumer( + TensorView* producer, + TensorView* consumer) const; + + MaxPosCalculator(ComputeAtMode mode); +}; + +// Propagate inline position to the `selected` tensors in the DAG. If `selected` +// is not specified or empty, then propagate to the entire DAG. +class TORCH_CUDA_CU_API InlinePropagator + : public MaxInfoSpanningTree::Propagator { + // Checks producers and consumers to see what the maximum position in tv is + // that can be shared across both directions. + size_t getMaxPosAll(TensorView* tv, bool check_siblings = true); + + // We use mapped_reference_pos_ to keep track of the outer axes information of + // the reference tensor. That is, mapped_reference_pos_[tv] answers the + // question "What outer axes in tv are shared with the specified reference + // tensor's outer axes?". However, when we actually set the CA position of tv, + // we might not want to set it as mapped_reference_pos_[tv] because because we + // don't want to inline certain things (such as vectorized dimensions, inner + // most broadcasting, etc.). + std::unordered_map mapped_reference_pos_; + + // Actually set the computeAt position. This does not necessarily equal to + // mapped_reference_pos_[tv] because we don't want to inline certain things. + void setCAPos(TensorView* tv); + + const MaxPosCalculator max_pos_calc; + std::unordered_set selected_; + TensorView* reference_; + size_t reference_pos_; + ComputeAtMode mode_ = ComputeAtMode::Standard; + bool is_first_ = true; + + public: + InlinePropagator( + TensorView* reference, + int64_t reference_pos, + ComputeAtMode mode = ComputeAtMode::Standard, + std::unordered_set selected = {}); + + InlinePropagator( + TensorView* reference, + int64_t reference_pos, + std::unordered_set selected) + : InlinePropagator( + reference, + reference_pos, + ComputeAtMode::Standard, + selected) {} + + ~InlinePropagator() = default; + + // Actually propagate the transformations for the inlining pass. Uses the + // functions above to figure out what position to do the propagation at. + virtual void propagateC2P(TensorView* from, TensorView* to) override; + virtual void propagateP2C(TensorView* from, TensorView* to) override; + virtual void propagateSibling(TensorView* from, TensorView* to) override; +}; + +// This is actually not a propagation, it only sets the max producer position of +// the tensors, and it is not needed to compute the max producer position in a +// specific order. But MaxInfoSpanningTree provides a very convenient API to +// visit the tensors, so I just use it for cleaner code. 
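+// Typical usage, as in ComputeAt::runAt / runWith in compute_at.cpp: the same
+// spanning tree is traversed once with an InlinePropagator and once with a
+// MaxProducerPosUpdater, e.g.
+//
+//   MaxRootDomainInfoSpanningTree path(reference, reference_pos, &selector);
+//   path.traverse(&inline_propagator);
+//   path.traverse(&updater);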
+class TORCH_CUDA_CU_API MaxProducerPosUpdater + : public MaxInfoSpanningTree::Propagator { + std::unordered_set updated_; + void handle(TensorView* tv); + + public: + virtual void propagateC2P(TensorView* from, TensorView* to) override; + virtual void propagateP2C(TensorView* from, TensorView* to) override; + virtual void propagateSibling(TensorView* from, TensorView* to) override; +}; + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp index 29843676b2b6e..eefaf361fd0d0 100644 --- a/torch/csrc/jit/codegen/cuda/ir_builder.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp @@ -53,6 +53,7 @@ IR_BUILDER_INSTANTIATE(NamedScalar) // Exprs IR_BUILDER_INSTANTIATE(Split) IR_BUILDER_INSTANTIATE(Merge) +IR_BUILDER_INSTANTIATE(Swizzle2D) IR_BUILDER_INSTANTIATE(TransposeOp) IR_BUILDER_INSTANTIATE(ExpandOp) IR_BUILDER_INSTANTIATE(ShiftOp) @@ -195,6 +196,27 @@ Val* IrBuilder::minExpr(Val* lhs, Val* rhs) { return newArithmeticExpr(BinaryOpType::Min, lhs, rhs); } +Val* IrBuilder::swizzle2DIntExpr( + Val* in_x, + Val* in_y, + Val* extent_x, + Val* extent_y, + Swizzle2DType swizzle_type) { + auto result = create(); + + create( + result, in_x, in_y, extent_x, extent_y, swizzle_type); + return result; +} + +Val* IrBuilder::pairSelectExpr(Val* in, kir::PairSelect::Selection sel) { + auto int_pair = dynamic_cast(in); + TORCH_INTERNAL_ASSERT(int_pair != nullptr); + auto result = create(); + create(result, int_pair, sel); + return result; +} + Val* SimplifyingIrBuilder::negExpr(Val* val) { if (auto int_val = dynamic_cast(val)) { if (int_val->isConst()) { diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.h b/torch/csrc/jit/codegen/cuda/ir_builder.h index f122232f8fb8e..af0e8cb1cc355 100644 --- a/torch/csrc/jit/codegen/cuda/ir_builder.h +++ b/torch/csrc/jit/codegen/cuda/ir_builder.h @@ -91,6 +91,15 @@ class TORCH_CUDA_CU_API IrBuilder { // Ternary operations static Val* whereExpr(Val* pred, Val* lhs, Val* rhs); + // Swizzle operations + static Val* swizzle2DIntExpr( + Val* x, + Val* y, + Val* extent_x, + Val* extent_y, + Swizzle2DType swizzle_type); + static Val* pairSelectExpr(Val* in, kir::PairSelect::Selection sel); + private: static Val* newResult(DataType dtype); static Val* newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs); diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp index 1da8b1e23a5ab..638e9d8c5a5f1 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp @@ -156,6 +156,10 @@ void IrCloner::handle(const Merge* merge) { clone_ = IrBuilder::clone(merge, this); } +void IrCloner::handle(const Swizzle2D* swizzle) { + clone_ = IrBuilder::clone(swizzle, this); +} + TensorView* RecomputeTv::recompute(TensorView* tv) { FusionGuard fg(tv->fusion()); diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h index 349fd68422238..a0f5d76f007d8 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.h +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h @@ -86,6 +86,7 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { void handle(const Split*) override; void handle(const Merge*) override; + void handle(const Swizzle2D*) override; protected: // We keep track of the original -> clone map so we don't diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 
46bf0801e2e9e..fa49e7bd9dba6 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -154,8 +154,10 @@ class TORCH_CUDA_CU_API ComplexDouble : public Val { //! the compute at position to maximum possible through traversal. enum class ComputeAtMode { Standard, BestEffort, MostInlined }; -class ComputeAt; +class InlinePropagator; +class MaxProducerPosUpdater; class TransformPropagator; +struct MostInlinedTransformPropagator; class TransformIter; class TransformReplay; class OptOutMutator; @@ -377,6 +379,10 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! \input axes Axes to swizzle TensorView* swizzle(SwizzleType type, const std::vector& axes); + //! Swizzle the rectangular tile defined by the iterdomains corresponding + //! to the 2 given indices. + TensorView* swizzle(Swizzle2DType swizzle_type, int x, int y); + // WARNING: rFactor does not return this TensorView, ir returns a new // tensorview consumed by this! // @@ -455,10 +461,19 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! More detail on usage see [WarpMmaSwizzler] in scheduler/mma_utils.h . void applyMmaSwizzle(MmaOptions options); + //! Returns if this tensor view has swizzle operator on its tensor domain. + //! This is the temporary flag for indicating that the new swizzle + //! implementation is used and will be removed in follow ups. + bool hasSwizzleOp() const { + return has_swizzle_op_; + } + friend TORCH_CUDA_CU_API TransformPropagator; + friend TORCH_CUDA_CU_API MostInlinedTransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; - friend ComputeAt; + friend TORCH_CUDA_CU_API InlinePropagator; + friend TORCH_CUDA_CU_API MaxProducerPosUpdater; friend class ir_utils::TVDomainGuard; friend TORCH_CUDA_CU_API void groupReductions( const std::vector&); @@ -501,6 +516,11 @@ class TORCH_CUDA_CU_API TensorView : public Val { // data, so we want to pass the data value as a standard kernel argument // value. bool cpu_scalar_ = false; + + //! Indicates if this tensor view has swizzle operator on its tensor domain. + //! This is the temporary flag for indicating that the new swizzle + //! implementation is used and will be removed in follow ups. + bool has_swizzle_op_ = false; }; //! A simple TensorView builder diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index f4a8ea1fe36d1..63cdfc19f126f 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -204,7 +204,9 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr { GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner); - size_t numReductions() const { + //! Number of expressions grouped horizontally. It does not reflect + //! iteration grouping. + size_t numExprs() const { return reduction_op_types_.size(); } @@ -231,7 +233,9 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr { bool sameAs(const Statement* other) const override; private: + //! Reduction ops of grouped reductions const std::vector reduction_op_types_; + //! Initial values of grouped reductions const std::vector init_vals_; //! True if using the fused reduction kernel bool is_allreduce_ = false; @@ -958,6 +962,13 @@ class TORCH_CUDA_CU_API IterDomain : public Val { return parallel_type_ == ParallelType::Mma; } + //! Applies 2D swizzle on a rectangular tile defined by + //! a pair of iterdomains. 
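As a usage sketch for the new TensorView::swizzle interface declared above, a hypothetical schedule that swizzles the inner tile of a shared-memory buffer; Swizzle2DType::ZShape is assumed to be one of the predefined swizzle functions, and the tile sizes are arbitrary:

    // Illustrative only: tile a 2D shared-memory tensor and swizzle the
    // inner 32x32 tile spanned by axes 1 and 3.
    void swizzleSharedTile(TensorView* smem_tv) {
      smem_tv->setMemoryType(MemoryType::Shared);
      smem_tv->split(0, 32);   // [I, J]        -> [I/32, 32, J]
      smem_tv->split(2, 32);   // [I/32, 32, J] -> [I/32, 32, J/32, 32]
      smem_tv->swizzle(Swizzle2DType::ZShape, 1, 3);
    }
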
+ static std::pair swizzle( + Swizzle2DType swizzle_type, + IterDomain* in_x, + IterDomain* in_y); + bool isMmaSwizzled() const { return is_mma_swizzled_; } @@ -1161,6 +1172,10 @@ class TORCH_CUDA_CU_API TensorDomain : public Val { // Reorder axes according to map[old_pos] = new_pos void reorder(const std::unordered_map& old2new); + //! Applies 2D swizzle on a rectangular tile defined by + //! a pair of iterdomains contained in this domain. + void swizzle(Swizzle2DType swizzle_type, int x, int y); + // Transform TensorView according to merge and split transformations TensorDomain* view( const std::vector>& transforms); @@ -1291,6 +1306,56 @@ class TORCH_CUDA_CU_API Merge : public Expr { IterDomain* const inner_ = nullptr; }; +//! Applies 2D swizzles on a rectangular tile defined by 2 iterdomains. +class TORCH_CUDA_CU_API Swizzle2D : public Expr { + public: + Swizzle2D( + IrBuilderPasskey, + IterDomain* out_x, + IterDomain* out_y, + IterDomain* in_x, + IterDomain* in_y, + Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle); + + Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner); + + IterDomain* outX() const { + return out_x_; + } + + IterDomain* outY() const { + return out_y_; + } + + IterDomain* inX() const { + return in_x_; + } + + IterDomain* inY() const { + return in_y_; + } + + const auto& swizzleType() const { + return swizzle_type_; + } + + bool sameAs(const Statement* other) const override; + + private: + // Output iterdomain pair corresponding + // to the original input iterdomain pair. + IterDomain* const out_x_ = nullptr; + IterDomain* const out_y_ = nullptr; + + // Input iterdomain pair. + IterDomain* const in_x_ = nullptr; + IterDomain* const in_y_ = nullptr; + + // The type of predefined 1-to-1 functions + // used for swizzling math. + Swizzle2DType swizzle_type_; +}; + //! Integer value which has a special name //! //! 
These could be: diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 0e8290fffb479..ea89a97a6a721 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -404,7 +404,7 @@ void IrPrinter::handle(const ReductionOp* rop) { void IrPrinter::handle(const GroupedReductionOp* grouped_rop) { indent() << "Grouped reduction(\n"; ++indent_size_; - for (const auto i : c10::irange(grouped_rop->numReductions())) { + for (const auto i : c10::irange(grouped_rop->numExprs())) { indent() << grouped_rop->output(i) << " = reduction( " << grouped_rop->input(i) << ", op = " << grouped_rop->getReductionOpType(i) @@ -472,6 +472,18 @@ void IrPrinter::handle(const Merge* m) { os_ << "\n"; } +void IrPrinter::handle(const Swizzle2D* s) { + os_ << s->swizzleType() << "(2D): "; + handle(s->inX()); + os_ << " , "; + handle(s->inY()); + os_ << " -> "; + handle(s->outX()); + os_ << " , "; + handle(s->outY()); + os_ << "\n"; +} + void IrPrinter::handle(const TransposeOp* top) { indent() << top->out() << " = transpose( " << top->in() << " )\n"; } @@ -666,7 +678,7 @@ void IrPrinter::handle(const kir::GridReduction* node) { void IrPrinter::handle(const kir::GroupedGridReduction* node) { indent() << "Grouped grid reduction(\n"; ++indent_size_; - for (const auto i : c10::irange(node->numReductions())) { + for (const auto i : c10::irange(node->numExprs())) { indent(); handle(node->output(i)); os_ << " = " @@ -691,7 +703,7 @@ void IrPrinter::handle(const kir::GroupedGridReduction* node) { os_ << "nullptr"; } os_ << "\n"; - for (const auto i : c10::irange(node->numReductions())) { + for (const auto i : c10::irange(node->numExprs())) { indent() << kTab << ".reduction_buffer="; handle(node->reduction_buffers().at(i)->buffer()); os_ << "\n"; @@ -781,6 +793,51 @@ void IrPrinter::handle(const kir::AllocateFusedReduction* node) { os_ << ")\n"; } +void IrPrinter::handle(const kir::IntPair* node) { + if (print_inline_) { + if (node->definition()) { + handle(node->definition()); + return; + } + } + os_ << "iPair" << varName(node); +} + +void IrPrinter::handle(const kir::Swizzle2DInt* node) { + if (!print_inline_) { + indent(); + handle(node->out()); + os_ << " = "; + } + + os_ << node->swizzleType() << "2D("; + handle(node->inX()); + os_ << ","; + handle(node->inY()); + os_ << ")"; +} + +void IrPrinter::handle(const kir::PairSelect* node) { + if (!print_inline_) { + indent(); + handle(node->out()); + os_ << " = "; + } + + handle(node->in()); + + switch (node->selection()) { + case kir::PairSelect::Selection::X: + os_ << ".x"; + break; + case kir::PairSelect::Selection::Y: + os_ << ".y"; + break; + default: + break; + } +} + void IrTransformPrinter::handle(Fusion* f) { auto all_vals = f->usedMathVals(); diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index c960d3975965c..6dbb2a646acd6 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -100,6 +100,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { void handle(const kir::Predicate*) final; void handle(const kir::TensorIndex*) final; + void handle(const kir::IntPair*) final; void handle(const kir::GridBroadcast*) final; void handle(const kir::GridReduction*) final; @@ -114,11 +115,14 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { void handle(const kir::InitMagicZero*) final; void handle(const kir::UpdateMagicZero*) final; void handle(const 
kir::AllocateFusedReduction*) final; + void handle(const kir::Swizzle2DInt*) final; + void handle(const kir::PairSelect*) final; // IR math printer overrides these to prevent them from printing, keep // override void handle(const Split*) override; void handle(const Merge*) override; + void handle(const Swizzle2D*) override; void print_inline(const Statement* stmt) { bool prev = print_inline_; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 64c4582220b2e..28545d04da640 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -460,7 +460,7 @@ bool GroupedReductionOp::sameAs(const Statement* other) const { return false; } - for (const auto i : c10::irange(numReductions())) { + for (const auto i : c10::irange(numExprs())) { if (!initVal(i)->sameAs(grouped_rop->initVal(i))) { return false; } @@ -1290,6 +1290,40 @@ std::pair IterDomain::stridedSplit(int factor) { return split_out; } +std::pair IterDomain::swizzle( + Swizzle2DType swizzle_type, + IterDomain* in_x, + IterDomain* in_y) { + TORCH_CHECK( + !in_x->extent()->isZeroInt() && !in_y->extent()->isZeroInt(), + "Invalid swizzling of a empty dimension."); + + // TODO: reduction check on swizzle: + TORCH_CHECK( + !in_x->isReduction() && !in_y->isReduction(), + "swizzled reduction not yet supported"); + + for (auto input : InputsOf::outputs(in_x->fusion(), {in_x, in_y})) { + TORCH_CHECK( + !input->as()->isBroadcast(), + "swizzling broadcast axes not yet supported"); + } + + // TODO: gather and shift check on swizzle + TORCH_INTERNAL_ASSERT( + !in_x->isGather() && !in_y->isGather(), + "Swizzled gather not yet supported"); + + IterDomain* out_x = IterDomainBuilder(in_x).build(); + + IterDomain* out_y = IterDomainBuilder(in_y).build(); + + IrBuilder::create( + in_x->container(), out_x, out_y, in_x, in_y, swizzle_type); + + return std::make_pair(out_x, out_y); +} + // TODO: We should change parallelize interface to be on tensorview or at least // vectorize should be done on tensorview. This would let us check that we don't // vectorize to the left of the computeAt domain, and could allow us to do some @@ -1300,10 +1334,11 @@ void IterDomain::parallelize(ParallelType t) { return; } - if (t == ParallelType::Unroll || isParallelTypeVectorize(t)) { + if (t == ParallelType::Unroll || isParallelTypeVectorize(t) || + t == ParallelType::Group) { TORCH_CHECK( start()->isZeroInt() && extent()->isConstScalar(), - "Vectorization, unrolling, and unswitching are only supported with start = 0 and extent as a const int, but got ", + "Vectorization, unrolling, unswitching and grouping are only supported with start = 0 and extent as a const int, but got ", "a start of ", start(), " and extent ", @@ -1311,6 +1346,13 @@ void IterDomain::parallelize(ParallelType t) { " ."); } + if (t == ParallelType::Group) { + TORCH_CHECK( + getIterType() == IterType::Iteration, + "Grouping IterDomain of non Iteration type is not allowed. 
", + getIterType()); + } + if (isMmaSwizzled()) { // Mma swizzled axes represent data representation within a warp // so only allow updates that keep the parallelization within @@ -1748,6 +1790,35 @@ std::vector TensorDomain::orderedAs( return reordered_domain; } +void TensorDomain::swizzle(Swizzle2DType swizzle_type, int x, int y) { + TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim domain"); + + TORCH_CHECK( + x >= 0 && (unsigned int)x < nDims(), + "Invalid swizzle detected, either one or both axes are outside of TensorView's range."); + + TORCH_CHECK( + y >= 0 && (unsigned int)y < nDims(), + "Invalid swizzle detected, either one or both axes are outside of TensorView's range."); + + IterDomain* axis_x = axis(x); + IterDomain* axis_y = axis(y); + + IterDomain* axis_out_x = nullptr; + IterDomain* axis_out_y = nullptr; + + std::tie(axis_out_x, axis_out_y) = + IterDomain::swizzle(swizzle_type, axis_x, axis_y); + + domain_.erase(domain_.begin() + x); + domain_.insert(domain_.begin() + x, axis_out_x); + + domain_.erase(domain_.begin() + y); + domain_.insert(domain_.begin() + y, axis_out_y); + + resetDomains(); +} + std::vector TensorDomain::noReductions( const std::vector& td) { size_t size_out = 0; @@ -1962,6 +2033,46 @@ bool Merge::sameAs(const Statement* other) const { return Expr::sameAs(other); } +Swizzle2D::Swizzle2D( + IrBuilderPasskey passkey, + IterDomain* out_x, + IterDomain* out_y, + IterDomain* in_x, + IterDomain* in_y, + Swizzle2DType swizzle_type) + : Expr(passkey, ExprType::Swizzle2D), + out_x_{out_x}, + out_y_{out_y}, + in_x_{in_x}, + in_y_{in_y}, + swizzle_type_(swizzle_type) { + addOutput(out_x); + addOutput(out_y); + addInput(in_x); + addInput(in_y); +} + +bool Swizzle2D::sameAs(const Statement* other) const { + if (this == other) { + return true; + } + if (!other->isA()) { + return false; + } + if (!(swizzle_type_ == other->as()->swizzle_type_)) { + return false; + } + return Expr::sameAs(other); +} + +Swizzle2D::Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner) + : Expr(src, ir_cloner), + out_x_(ir_cloner->clone(src->out_x_)), + out_y_(ir_cloner->clone(src->out_y_)), + in_x_(ir_cloner->clone(src->in_x_)), + in_y_(ir_cloner->clone(src->in_y_)), + swizzle_type_(src->swizzle_type_) {} + NamedScalar::NamedScalar( IrBuilderPasskey passkey, std::string name, diff --git a/torch/csrc/jit/codegen/cuda/ir_printer.h b/torch/csrc/jit/codegen/cuda/ir_printer.h index 91d07b76b8050..2cc0177787fb1 100644 --- a/torch/csrc/jit/codegen/cuda/ir_printer.h +++ b/torch/csrc/jit/codegen/cuda/ir_printer.h @@ -32,6 +32,7 @@ class TORCH_CUDA_CU_API IrMathPrinter : public IrPrinter { void handle(const Split* const) override {} void handle(const Merge* const) override {} + void handle(const Swizzle2D* const) override {} void handle(Fusion* f) override { IrPrinter::handle(f); diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index 98912c425c5a0..b371e51b245a8 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -529,6 +529,22 @@ TORCH_CUDA_CU_API std::vector consumerValsOf(Val* val) { return uniqueEntries(consumer_vals); } +// Return immediate siblings of val +TORCH_CUDA_CU_API std::vector siblingValsOf(Val* val) { + std::vector sibling_vals; + auto def = val->definition(); + if (def != nullptr) { + auto outs = def->outputs(); + for (auto sibling_val : outs) { + if (sibling_val == val) { + continue; + } + sibling_vals.emplace_back(sibling_val); + } + } + return sibling_vals; +} + // 
Return immediate producers of val TORCH_CUDA_CU_API std::vector producerValsOf( const std::vector& vals) { @@ -556,22 +572,21 @@ TORCH_CUDA_CU_API std::vector consumerValsOf( } std::vector producerTvsOf(TensorView* tv) { - if (tv->definition() == nullptr) { - return {}; - } - auto producer_vals = - ir_utils::filterByType(tv->definition()->inputs()); - return uniqueEntries( - {producer_vals.begin(), producer_vals.end()}); + auto producer_vals = producerValsOf(tv); + auto producer_tvs = ir_utils::filterByType(producer_vals); + return {producer_tvs.begin(), producer_tvs.end()}; } std::vector consumerTvsOf(TensorView* tv) { - std::vector consumer_tvs; - for (auto use_expr : tv->uses()) { - auto outputs = ir_utils::filterByType(use_expr->outputs()); - consumer_tvs.insert(consumer_tvs.end(), outputs.begin(), outputs.end()); - } - return uniqueEntries(consumer_tvs); + auto consumer_vals = consumerValsOf(tv); + auto consumer_tvs = ir_utils::filterByType(consumer_vals); + return {consumer_tvs.begin(), consumer_tvs.end()}; +} + +TORCH_CUDA_CU_API std::vector siblingTvsOf(TensorView* tv) { + auto sibling_vals = siblingValsOf(tv); + auto sibling_tvs = ir_utils::filterByType(sibling_vals); + return {sibling_tvs.begin(), sibling_tvs.end()}; } std::vector producerTvsOf(const std::vector& tvs) { @@ -752,7 +767,7 @@ Val* getReductionInitValOf(TensorView* tv) { init = rop->init(); } else if (auto grop = dynamic_cast(def)) { int output_idx = -1; - for (const auto i : c10::irange(grop->numReductions())) { + for (const auto i : c10::irange(grop->numExprs())) { if (tv == grop->output(i)) { output_idx = static_cast(i); break; diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.h b/torch/csrc/jit/codegen/cuda/ir_utils.h index b4c96ae147872..dd96eda69d608 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.h +++ b/torch/csrc/jit/codegen/cuda/ir_utils.h @@ -181,6 +181,16 @@ TORCH_CUDA_CU_API std::vector producerValsOf(Val* val); // code. TORCH_CUDA_CU_API std::vector consumerValsOf(Val* val); +// Return immediate siblings of val, this function can be used on any Val and +// will return siblings through Exprs. +// +// Warning: returned val's are not guaranteed to be between fusion inputs and +// outputs. This function simply uses val->definition() or val->uses() which is +// limited to not go through fusion inputs/outputs, but if on a path that isn't +// strictly between fusion inputs/outputs, it could effectively return dead +// code. +TORCH_CUDA_CU_API std::vector siblingValsOf(Val* val); + // Return immediate producers of vals, this function can be used on any vals and // will return producers through Exprs. // @@ -223,6 +233,16 @@ TORCH_CUDA_CU_API std::vector producerTvsOf(TensorView* tv); // code. TORCH_CUDA_CU_API std::vector consumerTvsOf(TensorView* tv); +// Return immediate siblings of tv, this function will return all immediate +// siblings of tv through Exprs. +// +// Warning: returned tv's are not guaranteed to be between fusion inputs and +// outputs. This function simply uses tv->definition() or tv->uses() which is +// limited to not go through fusion inputs/outputs, but if on a path that isn't +// strictly between fusion inputs/outputs, it could effectively return dead +// code. +TORCH_CUDA_CU_API std::vector siblingTvsOf(TensorView* tv); + // Return immediate producers of tvs, this function will return all immediate // producers of tvs through Exprs. 
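A small illustration of what "siblings" means here, assuming the Welford arith helper that returns avg/var_sum/n outputs; the three results are defined by a single Expr, so each is a sibling of the other two:

    // Illustrative fragment: sibling TensorViews are the other outputs of
    // the defining expression.
    void siblingExample(TensorView* tv0) {
      auto wf = Welford(tv0, {1});
      auto sibs = ir_utils::siblingTvsOf(wf.avg);
      // sibs holds wf.var_sum and wf.n; wf.avg itself is excluded.
      TORCH_INTERNAL_ASSERT(sibs.size() == 2);
    }
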
// diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index eddb49b6b2fc3..736eca4aa3e49 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -114,6 +114,57 @@ UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) "IR type only valid for Kernel container."); } +namespace { + +bool isIntegralScalar(const Val* val) { + return val->isScalar() && val->getDataType().has_value() && + isIntegralType(val->getDataType().value()); +} + +} // namespace + +IntPair::IntPair(IrBuilderPasskey passkey) + : Val(passkey, ValType::IntPair, DataType::Index) {} + +PairSelect::PairSelect( + IrBuilderPasskey passkey, + Val* out, + IntPair* in, + PairSelect::Selection selection) + : Expr(passkey, ExprType::PairSelect), + out_{out}, + in_{in}, + selection_(selection) { + addOutput(out); + addInput(in); + TORCH_INTERNAL_ASSERT(isIntegralScalar(out), "Integer only for this op"); +} + +Swizzle2DInt::Swizzle2DInt( + IrBuilderPasskey passkey, + IntPair* out, + Val* in_x, + Val* in_y, + Val* extent_x, + Val* extent_y, + Swizzle2DType swizzle_type) + : Expr(passkey, ExprType::Swizzle2DInt), + out_{out}, + in_x_{in_x}, + in_y_{in_y}, + extent_x_(extent_x), + extent_y_(extent_y), + swizzle_type_(swizzle_type) { + TORCH_INTERNAL_ASSERT(isIntegralScalar(in_x), "Integer only for this op"); + TORCH_INTERNAL_ASSERT(isIntegralScalar(in_y), "Integer only for this op"); + + addOutput(out); + addInput(in_x); + addInput(in_y); + addInput(extent_x); + addInput(extent_y); +} + void Scope::insert(std::vector::const_iterator pos, Expr* expr) { exprs_.insert(pos, expr); } @@ -475,6 +526,7 @@ GroupedGridReduction::GroupedGridReduction( Allocate* sync_buffer, Val* entrance_index, Val* entrances, + Val* buffer_stride, bool is_allreduce) : GroupedReductionOp( passkey, @@ -487,7 +539,8 @@ GroupedGridReduction::GroupedGridReduction( reduction_buffers_(std::move(reduction_buffers)), sync_buffer_(sync_buffer), entrance_index_(entrance_index), - entrances_(entrances) { + entrances_(entrances), + buffer_stride_(buffer_stride) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 730e0eb61804b..396ed9fb5f41b 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -587,6 +587,7 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { Allocate* sync_buffer, Val* entrance_index, Val* entrances, + Val* buffer_stride, bool is_allreduce = false); const std::vector& reduction_buffers() const { @@ -611,6 +612,10 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { return entrances_; } + Val* buffer_stride() const { + return buffer_stride_; + } + const ParallelTypeBitmap& threadPredicate() const { return thread_predicate_; } @@ -628,6 +633,8 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { ParallelTypeBitmap thread_predicate_; Val* entrance_index_ = nullptr; Val* entrances_ = nullptr; + // Stride of reduction buffers + Val* buffer_stride_ = nullptr; }; //! Grid broadcast operation @@ -762,6 +769,102 @@ class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr { Expr* grid_expr_ = nullptr; }; +//! An IR node consisting of a pair of integers +//! to facilitate definition of 2D swizzle operators. +//! All swizzle 2D ops takes two inputs and outputs +//! 
an integer pair. +//! TODO: +//! currently this IR node is only allowed as input +//! to the new PairSelect node. In follow ups would +//! possibly build out to support out of line +//! definition of the pair alone. +class TORCH_CUDA_CU_API IntPair : public Val { + public: + IntPair(IrBuilderPasskey passkey); +}; + +//! An IR node marking selection of first or second +//! value from a pair of integers, e.g.: +//! Pair(X,Y) -> X or Y. +//! This IR node is used to facilitate generation +//! of inline 2D swizzle math. +class TORCH_CUDA_CU_API PairSelect : public Expr { + public: + //! Indicates which value from the input + //! integer pair to output. + enum class Selection { X = 0, Y }; + + PairSelect(IrBuilderPasskey, Val* out, IntPair* in, Selection selection); + + Val* out() const { + return out_; + } + + IntPair* in() const { + return in_; + } + + auto selection() const { + return selection_; + } + + private: + Val* const out_ = nullptr; + IntPair* const in_ = nullptr; + Selection selection_; +}; + +//! An integer IR node that will be generated +//! using custom integer swizzle functions +//! from the cuda runtime functions. +//! Most supported swizzle functions require +//! the sizes of each dimension defined so +//! all operators will take the extents as inputs. +class TORCH_CUDA_CU_API Swizzle2DInt : public Expr { + public: + Swizzle2DInt( + IrBuilderPasskey, + IntPair* out, + Val* in_x, + Val* in_y, + Val* extent_x, + Val* extent_y, + Swizzle2DType swizzle_type); + + IntPair* out() const { + return out_; + } + + Val* inX() const { + return in_x_; + } + + Val* inY() const { + return in_y_; + } + + Val* extentX() const { + return extent_x_; + } + + Val* extentY() const { + return extent_y_; + } + + const auto& swizzleType() const { + return swizzle_type_; + } + + private: + IntPair* const out_ = nullptr; + + Val* const in_x_ = nullptr; + Val* const in_y_ = nullptr; + Val* const extent_x_ = nullptr; + Val* const extent_y_ = nullptr; + Swizzle2DType swizzle_type_; +}; + } // namespace kir } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp index a64b07da4a053..665e8d81532e8 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.cpp @@ -45,6 +45,46 @@ void IrVisitor::handle(IfThenElse* ite) { scope_exprs_.pop_back(); } +std::vector ConstIrVisitor::handle( + const std::vector& exprs) { + exprs_ = exprs; + for (auto expr : exprs) { + handle(expr); + } + return exprs_; +} + +void ConstIrVisitor::handle(const ForLoop* fl) { + for_loops_.push_back(fl); + scope_.push_back(&fl->body()); + scope_exprs_.push_back(fl); + auto body_exprs = fl->body().exprs(); + for (auto expr : body_exprs) { + handle(expr); + } + scope_exprs_.pop_back(); + scope_.pop_back(); + for_loops_.pop_back(); +} + +void ConstIrVisitor::handle(const IfThenElse* ite) { + scope_exprs_.push_back(ite); + scope_.push_back(&ite->thenBody()); + auto then_exprs = ite->thenBody().exprs(); + for (auto expr : then_exprs) { + handle(expr); + } + scope_.pop_back(); + + scope_.push_back(&ite->elseBody()); + auto else_exprs = ite->elseBody().exprs(); + for (auto expr : else_exprs) { + handle(expr); + } + scope_.pop_back(); + scope_exprs_.pop_back(); +} + std::vector ExprMutator::mutate(bool reverse_order) { if (insertions_.empty() && replacements_.empty() && removal_.empty()) { return exprs_; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h 
b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h index d665c4a6fdf53..139b4c37d45f1 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h @@ -45,6 +45,24 @@ class TORCH_CUDA_CU_API IrVisitor : public OptOutDispatch { std::vector exprs_; }; +// Const version of IrVisitor +class TORCH_CUDA_CU_API ConstIrVisitor : public OptOutConstDispatch { + public: + std::vector handle(const std::vector& expr); + + protected: + using OptOutConstDispatch::handle; + + virtual void handle(const ForLoop*) override; + virtual void handle(const IfThenElse*) override; + + protected: + std::vector for_loops_; + std::vector scope_; + std::vector scope_exprs_; + std::vector exprs_; +}; + // Base Expr Mutator class that visits all nodes with IrVisitor, and then // inserts new expressions, replaces expressions based on insertion/replace // maps provided or removes existing expressions. These replacement diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 69307bcae68bf..b8258da0b2342 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -273,6 +273,14 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { // created. vectorized_accesses_ and vectorized_set_info_ are filled. validateAndCollectVectorizeInfo(fusion_); + // Depends on ComputeAtMap and HaloInfo. + validateAndConvertIterDomainGrouping(fusion_); + + // Assumes all grouped reductions are convered to + // GroupedReductionOp, which is done by + // validateAndConvertIterDomainGrouping + validateGroupedReductions(fusion_); + // Depends on thread_pred_map_, validates parallelization collects which // tensor views need WAR or RAW syncs sync_map_.build(fusion_); diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp index ac1272c929af1..08517c441ebfd 100644 --- a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp @@ -19,6 +19,87 @@ namespace cuda { namespace { +//! Checks that the current loop nest is not realizing a serial +//! broadcast so that each index of producer buffer will only +//! be visited once, which is the only case where aggressive +//! inner sharing is valid. +//! +bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) { + //! Note: see issue #1785: + //! serial broadcast resolution doesn't only happen to + //! immediate producers of broadcast ops. We can also have + //! example: + //! T1[I,B] = broadcast(T0[I]]) + //! T3[I,I] = T1[I,B] + T2[I,I] + //! T4[I,I] = T3[I,I] + //! and generates the following loop: + //! alloc T0[4] + //! For i in 0..3 + //! T0[...] = + //! + //! For j in 0...X: + //! alloc T3[4] + //! for k in 0..3: + //! alloc T1[1] + //! T1[0] = T0[k] // <- This is actually a broadcast resolution + //! T3[k] = T1[0] + T2[...] + //! T4[...] = T3[...] + //! + //! In this case we are actually visiting each pixel of T0 in each iteration + //! of the j loop while T1 was the broadcasted tensor causing this reuse. + //! + //! The current version of checking covers this scenario by checking the root + //! ids of the consumer concrete loop id's. Any time a local tensor like T0 + //! appears in a re-use scenario like above, we should see a serial loop id + //! that was derived from some root id that doesn't concretely map to T0's + //! domain. + + // Serial concrete loop id's that cover consumer's iter domain. 
+ std::vector consumer_serial_loop_concrete_ids; + + for (auto consumer_leaf_id : consumer->domain()->domain()) { + auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID( + consumer_leaf_id, IdMappingMode::LOOP); + + // Check for any serial loop id with non-trivial extent + if (!concrete_loop_id->isThread() && + !concrete_loop_id->extent()->isOneInt()) { + consumer_serial_loop_concrete_ids.push_back(concrete_loop_id); + } + } + + // Collect the root id's that the serial loop iterdomain + // are transformed from. + auto serial_loop_roots = InputsOf::outputs( + FusionGuard::getCurFusion(), consumer_serial_loop_concrete_ids); + + // Collect exact concrete id's in producer's root domain + std::unordered_set producer_exact_concrete_root_ids; + auto producer_root = + TensorDomain::noReductions(producer->getMaybeRFactorDomain()); + std::transform( + producer_root.begin(), + producer_root.end(), + std::inserter( + producer_exact_concrete_root_ids, + producer_exact_concrete_root_ids.begin()), + ir_utils::caMapExactConcreteId); + + // Check if serial loop roots indexes any exact root id's that + // is not within the set of producer's root exact id's. These + // id's will imply that the same producer pixel is accessed + // in multiple iterations of the materialized serial loop. + for (auto serial_loop_root : + ir_utils::filterByType(serial_loop_roots)) { + if (!producer_exact_concrete_root_ids.count( + ir_utils::caMapExactConcreteId(serial_loop_root))) { + return true; + } + } + + return false; +} + //! Get string representation of Allocate size for symbolic comparison //! //! TODO: Some expr simplifications could also be helpful @@ -541,49 +622,6 @@ class BufferUseDefInfo { current_pos_ = -1; } - //! Checks that the current loop nest is not realizing a serial - //! broadcast so that each index of producer buffer will only - //! be visited once. - bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) { - auto producer_root = - TensorDomain::noReductions(producer->getMaybeRFactorDomain()); - auto consumer_root = - TensorDomain::noReductions(consumer->getMaybeRFactorDomain()); - - if (producer_root.size() != consumer_root.size()) { - // This case would be a single broadcast or a single reduce - // which wouldn't be a broadcast resolution - return true; - } - - std::vector serial_ids; - std::copy_if( - producer->domain()->domain().begin(), - producer->domain()->domain().end(), - std::back_inserter(serial_ids), - [](IterDomain* id) { return !id->isThread(); }); - - auto serial_producer_roots = - InputsOf::outputs(FusionGuard::getCurFusion(), serial_ids); - auto serial_root_id = - ir_utils::filterByType(serial_producer_roots); - std::unordered_set serial_producer_root_set( - serial_root_id.begin(), serial_root_id.end()); - - for (const auto idx : c10::irange(producer_root.size())) { - if (producer_root[idx]->isBroadcast() && - !consumer_root[idx]->isBroadcast()) { - // Check if this broadcast contributed to any serial - // scheduled iterdomains: - if (serial_producer_root_set.count(producer_root[idx])) { - return false; - } - } - } - - return true; - } - // Iterate over the inputs and outputs of exprs and update // the liveness info of local buffers if applicaable. 
void collectLivenessInfo(const Expr* expr) { @@ -599,7 +637,7 @@ class BufferUseDefInfo { for (auto input_tv : ir_utils::filterByType(expr->inputs())) { auto maybe_alloc_info = getMaybeAllocInfoFromTV(input_tv); if (maybe_alloc_info.has_value()) { - if (isSerialBroadcastResolution(input_tv, out_tv)) { + if (!isSerialBroadcastResolution(input_tv, out_tv)) { maybe_alloc_info.value()->inner_live_interval->markRead(current_pos_); } else { // Disable inner alias info for this buffer, since line number based @@ -1016,6 +1054,13 @@ class AllocateReuseModifier { auto this_tv = alloc_info->alloc_expr->buffer()->as(); auto reuse_tv = to_reuse->alloc_expr->buffer()->as(); + // Aggressively disable inner sharing for swizzled tvs since + // the indexing order is in general not tractable. + // But outer sharing should still apply. + if (this_tv->hasSwizzleOp() || reuse_tv->hasSwizzleOp()) { + return false; + } + // Check the values in between the two buffers. auto vals_between_this_and_reuse = DependencyCheck::getAllValsBetween({this_tv}, {reuse_tv}); @@ -1070,6 +1115,9 @@ class AllocateReuseModifier { InPlaceSharingInfo checkOpsInBetween(std::vector& all_used_vals) { InPlaceSharingInfo info; + std::unordered_set all_used_val_set( + all_used_vals.begin(), all_used_vals.end()); + for (auto val : all_used_vals) { if (auto tv = dynamic_cast(val)) { auto tv_def = tv->definition(); diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index 66bbb54dbb214..cb47644a5a03e 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -137,14 +137,20 @@ void IndexLowering::handle(const ViewAsScalar* uop) { namespace { +struct GridCommWorkBufferSizeInfo { + // Size of overall buffer. Can be expanded for privatization + Val* size_of_privatized_buffer = nullptr; + // Size of single buffer. + Val* buffer_stride = nullptr; +}; + // Get the size of the temporary work buffer for grid communication, this can be // grid reduction, broadcast, or grid welford. -// expansion_factor can be optionally passed to expand the allocation -// size. For example, FusedReduction should double the work buffer size. -Val* getGridCommWorkBufferSize( +// The buffer is expanded for privatization when not persistent or grouped. +GridCommWorkBufferSizeInfo getGridCommWorkBufferSize( const TensorDomain* td, - const std::vector& for_loops = {}, - int expansion_factor = 1) { + const std::vector& for_loops, + bool is_persistent) { // The buffer size is the number of thread blocks multiplied by the // number of threads not used for reduction domains. // Note: Previously it was calculated based on the shape of the @@ -154,11 +160,7 @@ Val* getGridCommWorkBufferSize( // size if the parallel dimensions are exact, but otherwise, just // computing the buffer size based on the tensor shape isn't // sufficient since there could be extra threads/blocks. - TORCH_INTERNAL_ASSERT( - expansion_factor >= 1, "Invalid expansion factor: ", expansion_factor); - Val* buffer_size = expansion_factor == 1 - ? 
GpuLower::current()->kernel()->oneVal() - : IrBuilder::create(expansion_factor); + Val* size_of_single_buffer = GpuLower::current()->kernel()->oneVal(); for (auto pt : kParallelTypeThreads) { auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt); if (pt_dim == nullptr || pt_dim->isOneInt()) { @@ -171,29 +173,58 @@ Val* getGridCommWorkBufferSize( })) { continue; } - buffer_size = SimplifyingIrBuilder::mulExpr(buffer_size, pt_dim); + size_of_single_buffer = + SimplifyingIrBuilder::mulExpr(size_of_single_buffer, pt_dim); } - // All iteration domains require a separate entry in the buffer for re-entrant - // grid reductions. + // Expand the buffer for privatization. The buffer is expanded so + // that each non-reduction IterDomain uses a different part of the + // buffer. For persistent mode, this expansion is only done for + // grouped IterDomains. + + Val* size_of_privatized_buffer = size_of_single_buffer; + + // In persistent mode, if non-grouped no-reduction domain is used, + // double the buffer size to save a final grid sync + bool is_doubled = false; + for (auto fl : for_loops) { - if (fl->isTrivial()) { + // Buffer size of parallelized domains are already taken care + if (fl->isTrivial() || fl->iter_domain()->isReduction() || + fl->iter_domain()->isThread()) { continue; } - if (fl->iter_domain()->isThread()) { - // already accounted for. - continue; + // If persistent, i.e., allreduce, only IterDomains with + // ParallelType::Group are privatized + if (!is_persistent || + fl->iter_domain()->getParallelType() == ParallelType::Group) { + size_of_privatized_buffer = SimplifyingIrBuilder::mulExpr( + size_of_privatized_buffer, fl->iter_domain()->extent()); + } else if (is_persistent) { + is_doubled = true; } - buffer_size = - SimplifyingIrBuilder::mulExpr(buffer_size, fl->iter_domain()->extent()); } - return buffer_size; + if (is_doubled) { + size_of_privatized_buffer = SimplifyingIrBuilder::mulExpr( + size_of_privatized_buffer, IrBuilder::create(2)); + } + + GridCommWorkBufferSizeInfo info; + info.size_of_privatized_buffer = size_of_privatized_buffer; + info.buffer_stride = size_of_single_buffer; + if (is_doubled) { + info.buffer_stride = SimplifyingIrBuilder::mulExpr( + info.buffer_stride, IrBuilder::create(2)); + } + + return info; } Val* getGridSyncBufferSize( const TensorDomain* td, - const std::vector& for_loops = {}) { + const std::vector& for_loops, + bool is_persistent) { // See the comment above for getGridCommWorkBufferSize. Val* buffer_size = GpuLower::current()->kernel()->oneVal(); for (auto pt : kParallelTypeBIDs) { @@ -210,19 +241,21 @@ Val* getGridSyncBufferSize( buffer_size = SimplifyingIrBuilder::mulExpr(buffer_size, pt_dim); } - // All iteration domains require a separate semaphore for re-entrant grid - // reductions - for (auto fl : for_loops) { - if (fl->isTrivial()) { - continue; + // If not persistent, all iteration domains require a separate + // semaphore for re-entrant grid reductions + if (!is_persistent) { + for (auto fl : for_loops) { + if (fl->isTrivial()) { + continue; + } + if (fl->iter_domain()->isThread()) { + // already accounted for. + continue; + } + + buffer_size = SimplifyingIrBuilder::mulExpr( + buffer_size, fl->iter_domain()->extent()); } - if (fl->iter_domain()->isThread()) { - // already accounted for. - continue; - } - - buffer_size = - SimplifyingIrBuilder::mulExpr(buffer_size, fl->iter_domain()->extent()); } return buffer_size; @@ -338,13 +371,6 @@ void IndexLowering::handleGridReduction( "then the grid reduction. 
", rop->toString()); - // When using the fused reduction in a loop, the global work buffer - // is double buffered to save global synchronizations. - auto is_within_a_loop = std::any_of( - out_domain->domain().begin(), - out_domain->domain().end(), - [](IterDomain* id) { return !isTrivialIterDomain(id); }); - // Use a unique buffer for work and sync flag when called within a // loop unless it's persistent. Grid all reduce means persistence is // required. However, not being a grid all reduce does not mean @@ -352,27 +378,26 @@ void IndexLowering::handleGridReduction( // required anywhere in the kernel, all grid reducitons are done in // a persistent manner, so all grid reductions should be consulted. // TODO: fix this - const bool privatize_buffer = !rop->isAllreduce(); - - const auto reduce_buffer = ir_utils::allocGlobalBufferForGridComm( - getGridCommWorkBufferSize( - out_domain, - privatize_buffer ? for_loops_ : std::vector(), - rop->isAllreduce() && is_within_a_loop ? 2 : 1), - out->dtype(), - false); - - const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm( - getGridSyncBufferSize( - out_domain, - privatize_buffer ? for_loops_ : std::vector()), - DataType::Int, - true); - - const auto entrance_ind = privatize_buffer + const bool is_persistent = rop->isAllreduce(); + const auto buffer_size_info = + getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent); + + auto work_buffer = allocateUniqueBuffer( + buffer_size_info.size_of_privatized_buffer, + out_tv->dtype(), + false, + out_tv, + work_buffer_map_); + + auto sync_buffer_size = + getGridSyncBufferSize(out_domain, for_loops_, is_persistent); + auto sync_buffer = allocateUniqueBuffer( + sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_); + + const auto entrance_ind = !is_persistent ? getEntranceLinIndGridReduce(for_loops_) : GpuLower::current()->kernel()->zeroVal(); - const auto n_entrances = privatize_buffer + const auto n_entrances = !is_persistent ? 
getEntranceCountGridReduce(for_loops_) : GpuLower::current()->kernel()->oneVal(); @@ -387,7 +412,7 @@ void IndexLowering::handleGridReduction( rop->init(), out, in, - reduce_buffer, + work_buffer, sync_buffer, entrance_ind, n_entrances, @@ -402,17 +427,11 @@ void IndexLowering::handleGridReduction( grid_reduction->setWritePredicate(rop->writePredicate()); } - pushBack(reduce_buffer); - pushBack(sync_buffer); pushBack(grid_reduction); GpuLower::current()->propagateExprInfo(rop, back()); if (rop->isAllreduce()) { - // When using the fused reduction, allocate the reduction object at - // the outer-most scope - auto fused_reduction_alloc_reduction = - IrBuilder::create(grid_reduction); - insertAtTopLevel(fused_reduction_alloc_reduction); + allocateUniqueFusedReduction(grid_reduction, out_tv); } } @@ -425,10 +444,10 @@ void IndexLowering::handle(const GroupedReductionOp* grouped_rop) { const bool has_block_reduce = out_domain->hasBlockReduction(); const bool has_grid_reduce = out_domain->hasGridReduction(); - std::vector indexed_outputs(grouped_rop->numReductions()); - std::vector indexed_inputs(grouped_rop->numReductions()); + std::vector indexed_outputs(grouped_rop->numExprs()); + std::vector indexed_inputs(grouped_rop->numExprs()); - for (const auto i : c10::irange(grouped_rop->numReductions())) { + for (const auto i : c10::irange(grouped_rop->numExprs())) { indexed_outputs.at(i) = lowerDstIndex(grouped_rop->output(i)); indexed_inputs.at(i) = lowerSrcIndex(grouped_rop->input(i), grouped_rop->output(i)); @@ -439,7 +458,7 @@ void IndexLowering::handle(const GroupedReductionOp* grouped_rop) { } else if (has_block_reduce) { handleBlockReduction(grouped_rop, indexed_outputs, indexed_inputs); } else { - for (const auto i : c10::irange(grouped_rop->numReductions())) { + for (const auto i : c10::irange(grouped_rop->numExprs())) { pushBack(IrBuilder::create( grouped_rop->getReductionOpType(i), indexed_outputs.at(i), @@ -496,41 +515,33 @@ void IndexLowering::handleGridReduction( "please use rfactor to do the serialized reduction first, ", "then the grid reduction."); - // When using the fused reduction in a loop, the global work buffer - // is double buffered to save global synchronizations. - auto is_within_a_loop = std::any_of( - out_domain->domain().begin(), - out_domain->domain().end(), - [](IterDomain* id) { return !isTrivialIterDomain(id); }); + const bool is_persistent = grouped_rop->isAllreduce(); + auto work_buf_size_info = + getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent); - const bool privatize_buffer = !grouped_rop->isAllreduce(); - - std::vector reduce_buffers; + std::vector work_buffers; std::transform( outputs.begin(), outputs.end(), - std::back_inserter(reduce_buffers), + std::back_inserter(work_buffers), [&](Val* output) { - return ir_utils::allocGlobalBufferForGridComm( - getGridCommWorkBufferSize( - out_domain, - privatize_buffer ? for_loops_ : std::vector(), - (grouped_rop->isAllreduce() && is_within_a_loop ? 2 : 1)), + return allocateUniqueBuffer( + work_buf_size_info.size_of_privatized_buffer, output->dtype(), - false); + false, + output->as()->view(), + work_buffer_map_); }); - const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm( - getGridSyncBufferSize( - out_domain, - privatize_buffer ? 
for_loops_ : std::vector()), - DataType::Int, - true); + auto sync_buffer_size = + getGridSyncBufferSize(out_domain, for_loops_, is_persistent); + auto sync_buffer = allocateUniqueBuffer( + sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_); - const auto entrance_ind = privatize_buffer + const auto entrance_ind = !is_persistent ? getEntranceLinIndGridReduce(for_loops_) : GpuLower::current()->kernel()->zeroVal(); - const auto n_entrances = privatize_buffer + const auto n_entrances = !is_persistent ? getEntranceCountGridReduce(for_loops_) : GpuLower::current()->kernel()->oneVal(); @@ -545,10 +556,11 @@ void IndexLowering::handleGridReduction( grouped_rop->initVals(), outputs, inputs, - reduce_buffers, + work_buffers, sync_buffer, entrance_ind, n_entrances, + work_buf_size_info.buffer_stride, grouped_rop->isAllreduce()); grid_reduction->setThreadPredicate(thread_pred); @@ -560,17 +572,11 @@ void IndexLowering::handleGridReduction( grid_reduction->setWritePredicate(grouped_rop->writePredicate()); } - for (auto reduce_buffer : reduce_buffers) { - pushBack(reduce_buffer); - } - pushBack(sync_buffer); pushBack(grid_reduction); GpuLower::current()->propagateExprInfo(grouped_rop, back()); if (grouped_rop->isAllreduce()) { - auto fused_reduction_alloc_reduction = - IrBuilder::create(grid_reduction); - insertAtTopLevel(fused_reduction_alloc_reduction); + allocateUniqueFusedReduction(grid_reduction, out_tv); } } @@ -654,40 +660,40 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) { const auto out_tv = indexed_wop->out()->as()->view(); const auto out_domain = out_tv->domain(); - // Buffer allocation - // When using the fused reduction in a loop, the global work buffer - // is double buffered to save global synchronizations. - auto is_within_a_loop = std::any_of( - out_domain->domain().begin(), - out_domain->domain().end(), - [](IterDomain* id) { return !isTrivialIterDomain(id); }); - // TODO: See the comment on the same variable in handleGridReduction - const bool privatize_buffer = !indexed_wop->isAllreduce(); - - const auto work_buffer_size = getGridCommWorkBufferSize( - out_domain, - privatize_buffer ? for_loops_ : std::vector(), - indexed_wop->isAllreduce() && is_within_a_loop ? 2 : 1); - - const auto out_var_buffer = ir_utils::allocGlobalBufferForGridComm( - work_buffer_size, indexed_wop->outVar()->dtype(), false); - const auto out_avg_buffer = ir_utils::allocGlobalBufferForGridComm( - work_buffer_size, indexed_wop->outAvg()->dtype(), false); - const auto out_N_buffer = ir_utils::allocGlobalBufferForGridComm( - work_buffer_size, indexed_wop->outN()->dtype(), false); - - const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm( - getGridSyncBufferSize( - out_domain, - privatize_buffer ? 
for_loops_ : std::vector()), - DataType::Int, - true); - - const auto entrance_ind = privatize_buffer + const bool is_persistent = indexed_wop->isAllreduce(); + const auto buffer_size_info = + getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent); + + const auto work_buffer_size = buffer_size_info.size_of_privatized_buffer; + auto out_var_buffer = allocateUniqueBuffer( + work_buffer_size, + indexed_wop->outVar()->dtype(), + false, + indexed_wop->outVar()->as()->view(), + work_buffer_map_); + auto out_avg_buffer = allocateUniqueBuffer( + work_buffer_size, + indexed_wop->outAvg()->dtype(), + false, + indexed_wop->outAvg()->as()->view(), + work_buffer_map_); + auto out_N_buffer = allocateUniqueBuffer( + work_buffer_size, + indexed_wop->outN()->dtype(), + false, + indexed_wop->outN()->as()->view(), + work_buffer_map_); + + auto sync_buffer_size = + getGridSyncBufferSize(out_domain, for_loops_, is_persistent); + auto sync_buffer = allocateUniqueBuffer( + sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_); + + const auto entrance_ind = !is_persistent ? getEntranceLinIndGridReduce(for_loops_) : GpuLower::current()->kernel()->zeroVal(); - const auto n_entrances = privatize_buffer + const auto n_entrances = !is_persistent ? getEntranceCountGridReduce(for_loops_) : GpuLower::current()->kernel()->oneVal(); @@ -729,19 +735,13 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) { GpuLower::current()->propagateExprInfo(indexed_wop, back()); } - pushBack(out_var_buffer); - pushBack(out_avg_buffer); - pushBack(out_N_buffer); - pushBack(sync_buffer); pushBack(grid_welford); GpuLower::current()->propagateExprInfo(indexed_wop, back()); if (indexed_wop->isAllreduce()) { // When using the fused reduction, allocate the reduction object at // the outer-most scope - auto fused_reduction_alloc_reduction = - IrBuilder::create(grid_welford); - insertAtTopLevel(fused_reduction_alloc_reduction); + allocateUniqueFusedReduction(grid_welford, out_tv); } } @@ -792,21 +792,24 @@ void IndexLowering::handle(const BroadcastOp* bop) { // Grid broadcast const auto out_domain = out_tv->domain(); - const auto broadcast_buffer = ir_utils::allocGlobalBufferForGridComm( - getGridCommWorkBufferSize(out_domain), out->dtype(), false); + const auto work_buffer_size = + getGridCommWorkBufferSize(out_domain, for_loops_, true) + .size_of_privatized_buffer; + + auto work_buffer = allocateUniqueBuffer( + work_buffer_size, out->dtype(), false, out_tv, work_buffer_map_); - const auto sync_buffer = ir_utils::allocGlobalBufferForGridComm( - getGridSyncBufferSize(out_domain), DataType::Int, true); + auto sync_buffer_size = getGridSyncBufferSize(out_domain, for_loops_, true); + auto sync_buffer = allocateUniqueBuffer( + sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_); auto grid_broadcast = IrBuilder::create( - indexed_expr, broadcast_buffer, sync_buffer); + indexed_expr, work_buffer, sync_buffer); if (bop->predicate()) { grid_broadcast->setPredicate(bop->predicate()); } - pushBack(broadcast_buffer); - pushBack(sync_buffer); pushBack(grid_broadcast); GpuLower::current()->propagateExprInfo(bop, back()); } @@ -837,6 +840,69 @@ void IndexLowering::generate(const std::vector& exprs) { } } +kir::Allocate* IndexLowering::allocateUniqueBuffer( + Val* buffer_size, + DataType dtype, + bool zero_init, + TensorView* out_tv, + std::unordered_map& alloc_map) { + // Return an existing allocation if exists + auto it = alloc_map.find(out_tv); + if (it != alloc_map.end()) { + return it->second; + } + + // No 
existing allocation found. Create a new one + auto new_buffer = + ir_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init); + + // Keep track of the allocation + alloc_map.emplace(out_tv, new_buffer); + + // A buffer may be used in both the unswitched paths, so it must be + // placed outside of the current scope. Simplying placing it at the + // top-level scope should work. + insertAtTopLevel(new_buffer); + + return new_buffer; +} + +void IndexLowering::allocateUniqueFusedReduction( + Expr* expr, + TensorView* out_tv) { + auto it = fused_reduction_map_.find(out_tv); + if (it != fused_reduction_map_.end()) { + return; + } + + kir::AllocateFusedReduction* fused_reduction_alloc_reduction = nullptr; + switch (expr->getExprType().value()) { + case ExprType::GridReduction: + fused_reduction_alloc_reduction = + IrBuilder::create( + expr->as()); + break; + case ExprType::GridWelford: + fused_reduction_alloc_reduction = + IrBuilder::create( + expr->as()); + break; + case ExprType::GroupedGridReduction: + fused_reduction_alloc_reduction = + IrBuilder::create( + expr->as()); + break; + default: + TORCH_INTERNAL_ASSERT(false, "Invalid expr: ", expr->toString()); + } + + fused_reduction_map_.emplace(out_tv, fused_reduction_alloc_reduction); + + // When using the fused reduction, allocate the reduction object at + // the outer-most scope + insertAtTopLevel(fused_reduction_alloc_reduction); +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h index 518b24a04c7a2..818dd98dca2b2 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_index.h @@ -76,6 +76,21 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { void handleGridWelford(WelfordOp* new_wop); + // Allocate a unique buffer for grid reductions and broadcast. A + // buffer is uniquely allocated for each output tensor of an + // expression. + kir::Allocate* allocateUniqueBuffer( + Val* buffer_size, + DataType dtype, + bool zero_init, + TensorView* out_tv, + std::unordered_map& alloc_map); + + // Allocate a fused reduction object uniquely for a given + // TensorView. Parameter expr is the expression corresponding to the + // fused reduction. + void allocateUniqueFusedReduction(Expr* expr, TensorView* out_tv); + private: std::vector lowered_exprs_; @@ -90,6 +105,13 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { // Track for loops to send to indexing. Similar to what's done in // kir::IrVisitor std::vector for_loops_; + + // Maps to keep track of allocated buffers and objects that must be + // allocated only once + std::unordered_map sync_buffer_map_; + std::unordered_map work_buffer_map_; + std::unordered_map + fused_reduction_map_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp index ad1610c40aa8f..837dfa24c5f48 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include #include @@ -268,6 +267,179 @@ IndexingParameters getNonGlobalInitialIndexParameters( return index_parameters; } +//! Initial index parameters for predicate, adjusts loop to indexing +//! may according to the information annotated on the loop nest. +//! +//! TODO: +//! This function is mostly copy pasted from previous implementation +//! 
at this step, further clean up is possible since: +//! 1. Much of the loop-to-ind adjustment will be issued from idgraph +//! 2. Much of the initial index logic could be shared across all +//! the 3 variants. +IndexingParameters getPredicateInitialIndexParameters( + const LoopIndexing& loop_indexing, + TensorView* consumer_tv, + kir::ForLoop* unswitch_or_vec_loop, + IterDomain* double_buffer_axis, + bool is_start_predicate) { + IndexingParameters index_parameters; + const auto& loops = loop_indexing.loops(); + const auto& loop_domains = loop_indexing.loopDomains(); + + // This shouldn't be needed. + TORCH_INTERNAL_ASSERT( + loops.size() <= loop_domains.size(), + "Loop domain didn't replay all loops"); + + std::unordered_map loop_to_ind_map; + + // Fill the initial index with each for loop's index. + std::transform( + loops.begin(), + loops.end(), + std::inserter(loop_to_ind_map, loop_to_ind_map.begin()), + [](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); }); + + // Generate unswitch loop to index map. + if (unswitch_or_vec_loop != nullptr) { + // Vectorized predicates are different from unswitch. With unswitch predicates, + // all loops within the unswitch (the outermost unswitch) are generated + // with loop->extent-1 as the index. With vectorized predicates, only the + // vectorized loop should be like this. + bool vectorized_pred = + unswitch_or_vec_loop->iter_domain()->getParallelType() == + ParallelType::Vectorize; + + bool within_unswitch = false; + + for (const auto loop_i : c10::irange(loops.size())) { + auto loop = loops[loop_i]; + auto loop_id = loop->iter_domain(); + auto loop_pt = loop_id->getParallelType(); + auto ref_id = loop_domains.at(loop_i); + + if (loop == unswitch_or_vec_loop) { + within_unswitch = true; + } + + if (within_unswitch) { + // Rely on the reference to check broadcasting. The for loop could be + // broadcasted on a constant value from an unroll split. Since the reference + // may convert this to an iter domain, that for loop could still be valid to + // generate predication from. + + // Note that loop->stop() is not used below. Instead, + // loop->iter_domain()->extent() is used, which is uniform + // across the mapped domains irrespective of halo. Predicates are + // compared with each other to pick the most restrictive ones. The + // comparison is done by only using the offset, which is the + // term added to the index. So, the index term must be the + // same among all predicates, otherwise the comparison would + // be invalid. The effect of halo is added to the offset + // term. See getUnswitchStopOffset. + + if (ref_id->isBroadcast()) { + // Ignore indexing into broadcasted dimensions. + continue; + } else if (loop_id->isThread()) { + // When parallelized, if the loop stop is the same as the + // extent of the associated IterDomain, i.e., no extra + // iterations for halo, predicating with the threading index + // is sufficient for both the start and stop + // predicates. That isn't the case if the loop has halo; + // in that case, either the minimum or the maximum value of the + // iteration domain needs to be used. + // + // Note: Better performance was obtained when the use of + // threadIdx in unswitch predicates was avoided. More + // specifically, in the Hdiff stencil example, instead of + // predicating with threadIdx.x for both the start and stop + // predicates, using zero and (blockDim.x - 1) for the start + // and stop predicates, respectively, resulted in less + // register pressure. 
The alternative codegen can be done by + // adding this to the first if condition: + // loop_id->isBlockDim(). This would not be a concern if the + // else part could be omitted, so canOmitElseClause should + // be used as well. + if (loop->stop() == loop_id->extent()) { + loop_to_ind_map[loop] = loop->start(); + } else if (is_start_predicate) { + loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal(); + } else { + // Note that the parallel dimension is used rather than + // loop-stop(). See the above comment. + loop_to_ind_map[loop] = + GpuLower::current()->parallelDimensionMap().get(loop_pt); + } + } else if (is_start_predicate) { + loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal(); + } else { + // Similar to the above, loop_id()->extent() is + // used here instead of loop->stop(). See the above comment. + loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr( + loop_id->extent(), GpuLower::current()->kernel()->oneVal()); + } + } + + // If a vectorized predicate, bail after the vectorized loop was found. + // Don't continue unswitching loops. + if (vectorized_pred && within_unswitch) { + break; + } + } + } + + // Modify trivial loops to use the loop start value. + // FIXME: eventually should be all lifted in idgraph. + for (const auto loop : loops) { + auto& idx = loop_to_ind_map.at(loop); + // If the loop is trivial, the loop index can only be the loop + // start value. + if (idx == loop->index() && loop->isTrivial()) { + idx = loop->start(); + } + } + + // Increment double buffer loop index + if (double_buffer_axis != nullptr) { + auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( + double_buffer_axis, loops, true); + if (db_loop != nullptr) { + auto loop_to_ind_map_it = loop_to_ind_map.find(db_loop); + TORCH_INTERNAL_ASSERT(loop_to_ind_map_it != loop_to_ind_map.end()); + auto cur_index = loop_to_ind_map_it->second; + // if cur_index is not the same as the index of db_loop, it must + // be true that that index has been modified to support + // unswitch. In that case, it is not necessary to move ahead the + // index for double buffering. + if (cur_index == db_loop->index()) { + loop_to_ind_map[db_loop] = SimplifyingIrBuilder::addExpr( + cur_index, GpuLower::current()->kernel()->oneVal()); + } + } + } + + // Convert loop-to-ind map to concrete-to-ind map + for (int loop_idx : c10::irange(loops.size())) { + auto loop = loops.at(loop_idx); + auto concrete_loop_domain = + ir_utils::caMapExactConcreteId(loop_domains.at(loop_idx)); + index_parameters.initial_concrete_id_index[concrete_loop_domain] = + loop_to_ind_map.at(loop); + } + + insertMagicZero( + loops, + loop_indexing.loopDomains(), + index_parameters.initial_concrete_id_index); + + // Derive the halo extents from the loop indexing result. + index_parameters.concrete_id_to_halo_extent = + GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing); + + return index_parameters; +} + } // namespace class LoopIndexingAnalysis { @@ -507,6 +679,16 @@ IterDomain* LoopIndexingAnalysis::concretizeAndVisitId(IterDomain* id) { } void LoopIndexingAnalysis::visitExpr(Expr* expr) { + if (auto swizzle2d = dynamic_cast(expr)) { + // Swizzle outputs are already forwarded through + // by exact CA map, so currently they are just + // ignored in the replay pass except + // that we want to note this node visited. 
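As an illustration of the double-buffer handling a few hunks above, here is a minimal standalone C++ sketch (not part of the patch; the loop and index names are made up, and symbolic strings stand in for nvfuser index Vals): the double-buffer loop index is advanced by one only when it has not already been rewritten for unswitch.

    // Standalone sketch: advance the double-buffer loop index only if it is
    // still the loop's own index; otherwise it was already rewritten (e.g. to
    // 0 or extent-1 for unswitch) and must be left alone.
    #include <iostream>
    #include <string>
    #include <unordered_map>

    struct ToyLoop {
      std::string name;
      std::string index; // symbolic index, e.g. "i2"
    };

    int main() {
      ToyLoop db_loop{"db_loop", "i2"};
      // loop name -> index expression currently assigned for indexing
      std::unordered_map<std::string, std::string> loop_to_ind_map{
          {"i0_loop", "i0"}, {"unswitched_loop", "0"}, {"db_loop", "i2"}};

      auto& cur_index = loop_to_ind_map.at(db_loop.name);
      if (cur_index == db_loop.index) {
        // Index untouched by unswitch handling: advance it for double buffering.
        cur_index = "(" + cur_index + " + 1)";
      }

      for (const auto& kv : loop_to_ind_map) {
        std::cout << kv.first << " -> " << kv.second << "\n";
      }
      return 0;
    }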
+ concretizeAndVisitId(swizzle2d->outX()); + concretizeAndVisitId(swizzle2d->outY()); + return; + } + // Current implementation just tries to // follow the exact behavior of reference replay // except that no expr was actually "replayed". @@ -715,6 +897,74 @@ IndexFromIdGraph getTensorIndexFromIdGraph( loop_indexing.loopDomains()); } +IndexFromIdGraph getPredicateIndexingFromIdGraph( + const std::vector& loops, + TensorView* consumer_tv, + kir::ForLoop* unswitch_or_vec_loop, + IterDomain* double_buffer_axis, + bool is_start_predicate) { + // Run replay pass on the loop nest to generate the deterministic + // traversal info from loop structure. + auto loop_indexing = + LoopIndexingAnalysis::fromLoopAndConsumer(loops, consumer_tv); + + // Bind initial index variables to the loop nodes and adjust + // according to loop and unswitch info. + auto index_parameters = getPredicateInitialIndexParameters( + loop_indexing, + consumer_tv, + unswitch_or_vec_loop, + double_buffer_axis, + is_start_predicate); + + // Run first backward traversal to generate + // loop nest based indexing math. + IndexCompute indexing( + index_parameters.initial_concrete_id_index, + index_parameters.zero_domains, + index_parameters.preferred_concrete_ids, + index_parameters.concrete_id_to_halo_extent); + + indexing.run(loop_indexing); + + // Map the concrete id indexing back to consumer tv + std::unordered_map index_update_map; + + // First collect all iterdomains in consumer transform history. + auto all_consumer_vals = DependencyCheck::getAllValsBetween( + {consumer_tv->getMaybeRFactorDomain().begin(), + consumer_tv->getMaybeRFactorDomain().end()}, + {consumer_tv->domain()->domain().begin(), + consumer_tv->domain()->domain().end()}); + + for (IterDomain* consumer_id : + ir_utils::filterByType(all_consumer_vals)) { + // Track the non-concrete id we were trying to bind index + // to, whether from producer or consumer. + auto exact_concrete_id = ir_utils::caMapExactConcreteId(consumer_id); + index_update_map[exact_concrete_id] = consumer_id; + } + + // No contiguity info is used in the predicate indexing pass, + // the predicate generation logic that uses the index math + // generated here will take contiguity into account. + ContigIDs contig_finder( + consumer_tv->domain()->domain(), + consumer_tv->getMaybeRFactorDomain(), + std::vector(consumer_tv->getMaybeRFactorDomain().size(), false), + {}); + + // Run second backward traversal to map back to the consumer_tv + auto target_indexing = indexing.updateIndexCompute( + consumer_tv->domain(), index_update_map, contig_finder); + + return IndexFromIdGraph( + target_indexing, + indexing, + index_parameters.initial_concrete_id_index, + loop_indexing.loopDomains()); +} + namespace { class LoopIndexingTraversal { @@ -892,6 +1142,139 @@ std::unordered_set LoopIndexing::getAllExactConcreteIdSet() const { return all_id_set; } +namespace { + +//! Returns true if id is mapped together with any id in +//! the vector ids by permissive compute at map. 
+bool isPermissivelyMappedWithAny(IterDomain* id, const std::vector& ids) { + return std::any_of(ids.begin(), ids.end(), [&](Val* val) { + return val->isA() && + GpuLower::current()->caMap()->areMapped( + id, val->as(), IdMappingMode::PERMISSIVE); + }); +} + +class LoopIndexingPreferredPathCompute : public IterVisitor { + public: + static std::unordered_set compute( + const TensorView* original_tv, + const LoopIndexing& loop_indexing, + bool use_replay_map, + const std::unordered_map& p2c_map) { + LoopIndexingPreferredPathCompute compute; + + auto all_concrete_ids = loop_indexing.getAllExactConcreteIdSet(); + + // Annotate all ids + auto all_original_ids = DependencyCheck::getAllValsBetween( + {original_tv->getMaybeRFactorDomain().begin(), + original_tv->getMaybeRFactorDomain().end()}, + {original_tv->domain()->domain().begin(), + original_tv->domain()->domain().end()}); + + for (auto original_id : + ir_utils::filterByType(all_original_ids)) { + auto mapped_id = original_id; + if (use_replay_map) { + auto c_id_it = p2c_map.find(original_id); + if (c_id_it == p2c_map.end()) { + continue; + } + mapped_id = c_id_it->second; + } + auto concrete_original_id = ir_utils::caMapExactConcreteId(mapped_id); + if (all_concrete_ids.count(concrete_original_id)) { + if (original_id->isBroadcast() || original_id->isReduction() || + original_id->isStride()) { + continue; + } + compute.preferred_path_.insert(concrete_original_id); + } + } + + for (auto expr : loop_indexing.getForwardExprList()) { + compute.handle(expr); + } + + return compute.preferred_path_; + } + + private: + void handle(Expr* e) override { + // If an input ID is marked, propagate the marking to outputs of the + // expression + auto all_iter_inputs = ir_utils::filterByType(e->inputs()); + if (std::any_of( + all_iter_inputs.begin(), + all_iter_inputs.end(), + [&](IterDomain* inp_id) { + return this->preferred_path_.find(ir_utils::caMapExactConcreteId( + inp_id)) != this->preferred_path_.end(); + })) { + auto all_iter_outputs = ir_utils::filterByType(e->outputs()); + + std::transform( + all_iter_outputs.begin(), + all_iter_outputs.end(), + std::inserter(preferred_path_, preferred_path_.end()), + ir_utils::caMapExactConcreteId); + } + } + + std::unordered_set preferred_path_; +}; + +} // namespace + +// External interface for preferred path propagation. +std::unordered_set buildLoopIndexingPreferredPath( + const TensorView* original_tv, + const LoopIndexing& loop_indexing, + bool use_replay_map, + std::unordered_map p2c_map) { + return LoopIndexingPreferredPathCompute::compute( + original_tv, loop_indexing, use_replay_map, p2c_map); +} + +// Get an rfactor IterDomain that is mapped with an IterDomain. If +// multiple such IDs exist, select one whose input IDs are mapped with +// the consumer IDs. This is to ensure the path from the leaf +// IterDomains to the root matches with the consumer tensor. 
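The selection rule documented above can be illustrated with a small standalone sketch (not part of the patch; plain strings stand in for IterDomains, and the permissive CA mapping is reduced to simple membership): a candidate is returned only if all of its definition inputs are mapped with some consumer ID, otherwise the first candidate serves as the fallback.

    // Standalone sketch: pick the first candidate whose inputs are all
    // "mapped" with a consumer id; fall back to the first candidate otherwise.
    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    struct Candidate {
      std::string name;
      std::vector<std::string> inputs; // inputs of the candidate's definition
    };

    bool isMappedWithAny(
        const std::string& id, const std::vector<std::string>& consumer_ids) {
      return std::any_of(consumer_ids.begin(), consumer_ids.end(),
                         [&](const std::string& c) { return c == id; });
    }

    std::string pickCandidate(const std::vector<Candidate>& candidates,
                              const std::vector<std::string>& consumer_ids) {
      for (const auto& c : candidates) {
        if (std::all_of(c.inputs.begin(), c.inputs.end(),
                        [&](const std::string& in) {
                          return isMappedWithAny(in, consumer_ids);
                        })) {
          return c.name;
        }
      }
      // No fully mapped candidate: the consumer is a post-view tensor, so any
      // path works; just take the first one.
      return candidates.front().name;
    }

    int main() {
      std::vector<Candidate> rfactor_ids{{"r0", {"a", "x"}}, {"r1", {"a", "b"}}};
      std::vector<std::string> consumer_ids{"a", "b", "c"};
      std::cout << pickCandidate(rfactor_ids, consumer_ids) << "\n"; // prints r1
      return 0;
    }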
+IterDomain* getRfactorIDToTraverse( + IterDomain* id, + const std::vector& consumer_all_ids) { + const auto& rfactor_ids = + GpuLower::current()->caMap()->getViewRfactorDomainsOfIdGroup( + id, IdMappingMode::PERMISSIVE); + + if (rfactor_ids.empty()) { + return nullptr; + } + + for (auto rfactor_id : rfactor_ids) { + auto def = rfactor_id->definition(); + if (def == nullptr) { + continue; + } + + auto rfactor_id_inputs = ir_utils::filterByType(def->inputs()); + if (std::all_of( + rfactor_id_inputs.begin(), + rfactor_id_inputs.end(), + [&](IterDomain* rfactor_id_input) { + return isPermissivelyMappedWithAny( + rfactor_id_input, consumer_all_ids); + })) { + return rfactor_id; + } + } + + // No mapped ID found, which means the consumer is a post-view + // tensor. In that case, it shouldn't matter which view path to + // traverse, so just return the first one. + return rfactor_ids.at(0); +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.h b/torch/csrc/jit/codegen/cuda/lower_index_compute.h index a10931e925964..2a82ec2a5cfda 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index_compute.h +++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.h @@ -36,6 +36,17 @@ IndexFromIdGraph getTensorIndexFromIdGraph( bool is_global = true, std::unordered_map c2p_map = {}); +//! Indexing interface for calculating a predicate index. Returns an IndexFromIdGraph +//! from which the IndexCompute object can be queried directly for the produced +//! indexing. If is_start_predicate is true, indexing math is produced for the start +//! predicates. +IndexFromIdGraph getPredicateIndexingFromIdGraph( + const std::vector& loops, + TensorView* consumer_tv, + kir::ForLoop* unswitch_or_vec_loop, + IterDomain* double_buffer_axis, + bool is_start_predicate); + //! getTensorIndexFromIdGraph is the function that index_compute will call very //! straightforwardly. However, for implementing the new indexing logic that //! starts to abstract some of the indexing away from index_compute we need to @@ -142,6 +153,24 @@ class LoopIndexing { std::vector index_exprs_; }; +// When indexing, there is sometimes an option to propagate an index down +// multiple paths. This will return the IterDomains in the history of the +// reference domain and mark which paths should be taken (if there's a +// preference) to reach the roots provided in preferred_roots. +std::unordered_set buildLoopIndexingPreferredPath( + const TensorView* original_tv, + const LoopIndexing& loop_indexing, + bool use_replay_map = false, + std::unordered_map p2c_map = {}); + +// Get an rfactor IterDomain that is mapped with an IterDomain. If +// multiple such IDs exist, select one whose input IDs are mapped with +// the consumer IDs. This is to ensure the path from the leaf +// IterDomains to the root matches with the consumer tensor. 
+IterDomain* getRfactorIDToTraverse( + IterDomain* id, + const std::vector& consumer_all_ids); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp b/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp index 3eb6080db82e3..f0c342d921a93 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp @@ -229,7 +229,7 @@ std::pair CommonIndexMap::insert( const std::vector& loops, Val* index) { if (index->definition() == nullptr) { - // Only expression is eligible to hoist + // Only defined val is eligible to hoist return {index, false}; } @@ -246,7 +246,7 @@ std::pair CommonIndexMap::insert( const std::vector& loops, Val* index) { if (index->definition() == nullptr) { - // Only expression is eligible to hoist + // Only defined val is eligible to hoist return {index, false}; } @@ -262,6 +262,15 @@ std::pair CommonIndexMap::tryInsertNewIndex( Val* hoisted_index = nullptr; bool new_index_inserted = false; + // Hoisting is not possible if any of used loops is grouped. + if (std::any_of( + key.usedLoops().begin(), key.usedLoops().end(), [](const auto loop) { + return loop->iter_domain()->getParallelType() == + ParallelType::Group; + })) { + return {index, false}; + } + // If already mapped, return the previously mapped index auto it = common_index_map_.find(key); if (it != common_index_map_.end()) { diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 1315b0fc93b9b..2005ee751d66b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -601,7 +601,7 @@ class PredicateChcker : public IterVisitor { } void handle(GroupedReductionOp* grouped_rop) final { - for (const auto i : c10::irange(grouped_rop->numReductions())) { + for (const auto i : c10::irange(grouped_rop->numExprs())) { auto input = grouped_rop->input(i)->as(); auto input_def = input->definition(); // When input_def is null, input must be an input to the fusion, diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp index 17eb863486a90..fe1e0cc509c13 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp @@ -445,6 +445,11 @@ void HaloInfo::build(TensorDomain* td) { } else { setHaloWidth(merge->out(), 0); } + } else if (expr->getExprType().value() == ExprType::Swizzle2D) { + // Assume no halo on swizzled domain for now. + for (auto id : ir_utils::filterByType(expr->outputs())) { + setHaloWidth(id, 0); + } } else { TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr); } @@ -798,6 +803,20 @@ std::unordered_map HaloInfo::buildConcreteHaloExtentMap( } else { setHaloWidth(ir_utils::caMapExactConcreteId(merge->out()), 0); } + } else if (auto swizzle_2d = dynamic_cast(expr)) { + // Swizzle with halo not yet supported, just set the width + // to zero at the moment. 
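A minimal standalone sketch of the halo bookkeeping described above (not part of the patch; a plain map stands in for HaloInfo, and the names are illustrative): swizzled inputs are required to be halo-free, and the swizzle outputs are simply recorded with zero halo width.

    // Standalone sketch: propagate halo widths across a swizzle node.
    #include <cassert>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
      // Halo widths already known for the swizzle inputs.
      std::unordered_map<std::string, int> halo_width{{"inX", 0}, {"inY", 0}};

      std::vector<std::string> swizzle_inputs{"inX", "inY"};
      std::vector<std::string> swizzle_outputs{"outX", "outY"};

      // Swizzle with halo is not supported: inputs must have zero halo width.
      for (const auto& in : swizzle_inputs) {
        assert(halo_width.at(in) == 0 && "Swizzle on ID with halo not supported");
      }
      // Outputs are marked halo-free as well.
      for (const auto& out : swizzle_outputs) {
        halo_width[out] = 0;
      }

      for (const auto& kv : halo_width) {
        std::cout << kv.first << ": halo width " << kv.second << "\n";
      }
      return 0;
    }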
+ TORCH_INTERNAL_ASSERT( + local_halo_info.getHaloWidth( + ir_utils::caMapExactConcreteId(swizzle_2d->inX())) == 0 && + local_halo_info.getHaloWidth( + ir_utils::caMapExactConcreteId(swizzle_2d->inY())) == 0, + "Swizzle on ID with halo not yet supported."); + local_halo_info.setHaloWidth( + ir_utils::caMapExactConcreteId(swizzle_2d->outX()), 0); + local_halo_info.setHaloWidth( + ir_utils::caMapExactConcreteId(swizzle_2d->outY()), 0); } else { TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr); } diff --git a/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp b/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp index 5f3eebceb3033..d29f9f2677dff 100644 --- a/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp @@ -339,6 +339,20 @@ void SyncMap::build(Fusion* fusion) { } } + // If any leaf id of the producer is block or grid parallel and is + // involved in any swizzle pattern, track this parallel dim as a communication + // dimension that requires the corresponding synchronization and + // memory type. + if (isParallelTypeThread(producer_ptype) && + producer->hasSwizzleOp()) { + if (!ir_utils::getAllSwizzlesBetween( + producer->getMaybeRFactorDomain(), {p_id}) + .empty()) { + raw_dims.set(producer_ptype); + } + } + // In shift or gather operations, if a thread or block // domain's root ID is shifted or gathered, it can overlap // in shared or global memory. This doesn't diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index c1e948d45901d..6c4b96d1b394f 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -414,6 +414,26 @@ IterDomain* caMapExactConcreteId(IterDomain* id) { id, IdMappingMode::EXACT); } +std::vector getAllSwizzlesBetween( + std::vector from, + std::vector to) { + auto all_expr = DependencyCheck::getAllExprsBetween( + {from.begin(), from.end()}, {to.begin(), to.end()}); + + std::vector all_swizzles; + + std::copy_if( + all_expr.begin(), + all_expr.end(), + std::back_inserter(all_swizzles), + [](Expr* expr) { + return expr->getExprType().has_value() && + (expr->etype() == ExprType::Swizzle2D); + }); + + return all_swizzles; +} + } // namespace ir_utils namespace loop_utils { diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index 53cad32c04731..c6e48f5d503a0 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -47,6 +47,13 @@ class TVDomainGuard { public: explicit TVDomainGuard(TensorView* _tv, TensorDomain* td); + //! A utility to access the TensorDomain before the temporary + //! view. This is used to retrieve information, like swizzle + //! information, that can only be reliably kept at the original domain. + const TensorDomain* prevDomain() const { + return prev_domain; + } + ~TVDomainGuard(); }; @@ -64,13 +71,13 @@ std::vector iterDomainInputsOfOrderedAs( const std::vector& order); // Returns if Val is a TensorView or TensorIndex -bool isTV(const Val* const); +TORCH_CUDA_CU_API bool isTV(const Val* const); // Returns if Expr is a TensorView or TensorIndex Expr. 
TORCH_CUDA_CU_API bool isTvOp(const Expr*); // Returns the first output of Expr that is a TensorView -TensorView* getTvOutput(const Expr*); +TORCH_CUDA_CU_API TensorView* getTvOutput(const Expr*); // Returns if Expr is a reduction op TORCH_CUDA_CU_API bool isReductionOp(const Expr*); @@ -150,6 +157,12 @@ TORCH_CUDA_CU_API std::vector flattenScopedExprs( //! the exact compute at map. IterDomain* caMapExactConcreteId(IterDomain* id); +//! Returns all swizzle ops between the set of iterdomains +//! in `from` and `to`. +std::vector getAllSwizzlesBetween( + std::vector from, + std::vector to); + } // namespace ir_utils namespace loop_utils { diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 6f1c266c385ab..bfdb44402dfce 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -1037,6 +1037,152 @@ void validateMma(Fusion* fusion) { } } +void validateSwizzle(Fusion* fusion) { + auto used_vals = fusion->usedMathVals(); + for (auto tv : ir_utils::filterByType(used_vals)) { + if (tv->hasSwizzleOp()) { + // Make sure no swizzle op is inlined: + auto inlined_swizzles = ir_utils::getAllSwizzlesBetween( + tv->getMaybeRFactorDomain(), + {tv->domain()->domain().begin(), + tv->domain()->domain().begin() + tv->getComputeAtPosition()}); + TORCH_INTERNAL_ASSERT( + inlined_swizzles.empty(), "No support for inlined swizzles"); + } + } +} + +void validateAndConvertIterDomainGrouping(Fusion* fusion) { + for (auto tv : ir_utils::allTvs(fusion)) { + bool is_grouped = false; + for (const auto id_idx : c10::irange(tv->nDims())) { + const auto id = tv->axis(id_idx); + auto ptype = GpuLower::current() + ->caMap() + ->getConcreteMappedID(id, IdMappingMode::LOOP) + ->getParallelType(); + if (ptype != ParallelType::Group) { + // Not a grouped ID + continue; + } + + // Remember if a grouped ID is found + is_grouped = true; + + // Grouping only makes sense for the normal iteration type + TORCH_CHECK( + id->getIterType() == IterType::Iteration, + "Invalid use of ParallelType::Group.", + " Grouping of ", + id->getIterType(), + " is not allowed. ", + tv->toString()); + + // Extent must be static + TORCH_CHECK( + id->extent()->getInt().has_value(), + "Invalid use of ParallelType::Group.", + " IterDomain must have a static extent: ", + id->toString()); + + // The CA position must be left of any grouped ID + TORCH_CHECK( + tv->getComputeAtPosition() <= id_idx, + "Invalid use of ParallelType::Group.", + " ComputeAt position must be left of grouped IDs: ", + tv->toString()); + + // Similarly, the produce-at position must be left of any grouped ID + TORCH_CHECK( + tv->getMaxProducerPosition() <= id_idx, + "Invalid use of ParallelType::Group.", + " ProduceAt position must be left of grouped IDs: ", + tv->toString()); + + // Halo is not allowed + TORCH_CHECK( + GpuLower::current()->haloInfo().getExtent(id) == nullptr, + "Invalid use of ParallelType::Group.", + " Grouping of halo-extended IterDomain, ", + id->toString(), + ", is not supported. ", + tv->toString()); + } + + if (!is_grouped) { + continue; + } + + // Must be defined by ReductionOp + auto def = tv->definition(); + TORCH_CHECK( + def != nullptr, + "Invalid use of ParallelType::Group.", + " Definition of tv with ParallelType::Group not found. ", + tv->toString()); + + TORCH_CHECK( + tv->definition()->isA() || + tv->definition()->isA(), + "Invalid use of ParallelType::Group. Only ReductionOp and GroupedReductionOp are allowed. 
", + tv->definition()->toString()); + + // Convert ReductionOp to GroupedReductionOp + if (tv->definition()->isA()) { + auto rop = def->as(); + auto is_allreduce = rop->isAllreduce(); + + TORCH_CHECK( + is_allreduce, + "Invalid use of ParallelType::Group.", + " Only enabled for allreduce reductions: ", + rop->toString()); + + TORCH_CHECK( + tv->domain()->hasGridReduction(), + "Invalid use of ParallelType::Group.", + " Only enabled for grid reductions: ", + rop->toString()); + + std::vector op_types({rop->getReductionOpType()}); + std::vector init_vals({rop->init()}); + std::vector outputs({rop->out()}); + std::vector inputs({rop->in()}); + + fusion->removeExpr(rop); + IrBuilder::create( + static_cast(fusion), + op_types, + init_vals, + outputs, + inputs, + is_allreduce); + } + } +} + +void validateGroupedReductions(Fusion* fusion) { + for (auto expr : StmtSort::getExprs(fusion)) { + if (auto grouped_reduction_op = dynamic_cast(expr)) { + const auto num_exprs = grouped_reduction_op->numExprs(); + int num_grouped_iterations = 1; + auto out_tv = ir_utils::getTvOutput(grouped_reduction_op); + for (auto axis : out_tv->domain()->domain()) { + if (axis->getParallelType() == ParallelType::Group) { + num_grouped_iterations *= axis->extent()->getInt().value(); + } + } + TORCH_CHECK( + num_exprs * num_grouped_iterations <= kMaxNumGroupedReductions, + "Too many grouped reductions: ", + grouped_reduction_op->toString(), + ". Up to ", + kMaxNumGroupedReductions, + " reductions are allowed."); + } + } +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h index d8c95d8d1f057..47305ac25ef4b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.h +++ b/torch/csrc/jit/codegen/cuda/lower_validation.h @@ -44,6 +44,32 @@ void validatePartialSplit(Fusion* fusion); //! mma operators on the fusion. void validateMma(Fusion* fusion); +//! Validates swizzle ops to ensure consistent indexing: +//! - Currently only allow swizzle ops on the right of CA axis, +//! - (Except ZShape) All swizzle ops have to be on const sized ids +//! - Xor and Transpose swizzle have to have equal dimensions on the +//! participating ids. +void validateSwizzle(Fusion* fusion); + +//! Validate use of ParallelType::Group. It is currently only allowed +//! in ReductionOp and not in WelfordOp. Group has similar constraints +//! as Vectorize, e.g., it can only be used with IterDomains with +//! static extents. Differences are, e.g, it has no constraints on +//! alignments and predicates. Each individual reduction has its own +//! predicate, so it is possile for only part of grouped reductions to +//! be executed. +//! +//! Also, grouping is only enabled for persistent grid reductions, in +//! other words, grid allreduces. Note that no grid reduction without +//! broadcast is persistent anymore. +//! +//! Validated ReductionOp with ParallelType::Group is converted to +//! GroupedReductionOp. +void validateAndConvertIterDomainGrouping(Fusion* fusion); + +//! 
Validate the number of grouped reductions is within the limit +void validateGroupedReductions(Fusion* fusion); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp index cf026ad7f4477..4117ee0d6b6b6 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp @@ -66,18 +66,25 @@ void MaxInfoSpanningTree::compute_spanning_tree() { } }; - auto allowPasC = [this](TensorView* from, TensorView* to) { + auto allowC2P = [this](TensorView* from, TensorView* to) { if (selector_ == nullptr) { return true; } - return selector_->allowPasC(from, to); + return selector_->allowC2P(from, to); }; - auto allowCasP = [this](TensorView* from, TensorView* to) { + auto allowP2C = [this](TensorView* from, TensorView* to) { if (selector_ == nullptr) { return true; } - return selector_->allowCasP(from, to); + return selector_->allowP2C(from, to); + }; + + auto allowSibling = [this](TensorView* from, TensorView* to) { + if (selector_ == nullptr) { + return true; + } + return selector_->allowSibling(from, to); }; while (!candidates.empty()) { @@ -91,8 +98,19 @@ void MaxInfoSpanningTree::compute_spanning_tree() { } replayed.emplace(next_hop.to); + for (auto sibling_tv : ir_utils::siblingTvsOf(next_hop.to)) { + if (replayed.count(sibling_tv) || + !allowSibling(next_hop.to, sibling_tv)) { + continue; + } + insertNextHop(NextHopWithInfo( + NextHop(NextHopType::SIBLING, next_hop.to, sibling_tv), + next_hop_info.info_to, + computeInfoSibling(next_hop.to, sibling_tv, next_hop_info.info_to))); + } + for (auto consumer_tv : ir_utils::consumerTvsOf(next_hop.to)) { - if (replayed.count(consumer_tv) || !allowCasP(next_hop.to, consumer_tv)) { + if (replayed.count(consumer_tv) || !allowP2C(next_hop.to, consumer_tv)) { continue; } insertNextHop(NextHopWithInfo( @@ -102,7 +120,7 @@ void MaxInfoSpanningTree::compute_spanning_tree() { } for (auto producer_tv : ir_utils::producerTvsOf(next_hop.to)) { - if (replayed.count(producer_tv) || !allowPasC(next_hop.to, producer_tv)) { + if (replayed.count(producer_tv) || !allowC2P(next_hop.to, producer_tv)) { continue; } insertNextHop(NextHopWithInfo( @@ -119,11 +137,14 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) { } for (const auto& next_hop : path_) { switch (next_hop.type) { + case NextHopType::SIBLING: + propagator->propagateSibling(next_hop.from, next_hop.to); + break; case NextHopType::C_AS_P: - propagator->propagateTvCasP(next_hop.from, next_hop.to); + propagator->propagateP2C(next_hop.from, next_hop.to); break; case NextHopType::P_AS_C: - propagator->propagateTvPasC(next_hop.from, next_hop.to); + propagator->propagateC2P(next_hop.from, next_hop.to); break; } } @@ -372,6 +393,35 @@ MaxRootDomainInfoSpanningTree::getReferenceRootIDInfo( return std::make_shared(std::move(result)); } +// Given the preserved reference root ID info of a tensor, compute +// the corresponding info in its sibling. Since info has nothing to do with +// replay state, so sibling info is always identical by definition. 
+std::shared_ptr MaxRootDomainInfoSpanningTree:: + computeInfoSibling( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) const { + return from_info; +} + +void SpanningTreePrinter::propagateC2P(TensorView* from, TensorView* to) { + stream_ << "propagateC2P" << std::endl; + stream_ << " from: " << from->toString() << std::endl; + stream_ << " to: " << to->toString() << std::endl; +} + +void SpanningTreePrinter::propagateP2C(TensorView* from, TensorView* to) { + stream_ << "propagateP2C" << std::endl; + stream_ << " from: " << from->toString() << std::endl; + stream_ << " to: " << to->toString() << std::endl; +} + +void SpanningTreePrinter::propagateSibling(TensorView* from, TensorView* to) { + stream_ << "propagateSibling" << std::endl; + stream_ << " from: " << from->toString() << std::endl; + stream_ << " to: " << to->toString() << std::endl; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h index 7d0665d148ea5..231f84bed897c 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h @@ -29,8 +29,9 @@ namespace cuda { * MaxInfoSpanningTree::Information and implement `operator<` which is used to * tell which path contains more information, and `operator bool` which is used * to tell if there is any information stored. You also need to implement - * computeInfoPasC and computeInfoCasP, which are the functions that compute - * information of the `to` tensor from the information of the `from` tensor. + * computeInfoPasC, computeInfoCasP, and computeInfoSibling, which are the + * functions that compute information of the `to` tensor from the information of + * the `from` tensor. */ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class TORCH_CUDA_CU_API MaxInfoSpanningTree { @@ -38,15 +39,17 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree { // Class to subclass in order to stop traversal, by which limits the nodes in // the spanning tree. 
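To illustrate how the new sibling hop and the renamed C2P/P2C hooks fit together, here is a small standalone sketch (not part of the patch; it uses a toy string graph, and unlike MaxInfoSpanningTree it does a plain breadth-first walk rather than ranking candidates by preserved root-domain information): siblings are offered to the selector and propagator first, then consumers (P2C), then producers (C2P).

    // Standalone sketch: traversal with sibling/P2C/C2P hops, a selector that
    // can veto hops, and a propagator that receives them in traversal order.
    #include <functional>
    #include <iostream>
    #include <map>
    #include <queue>
    #include <set>
    #include <string>
    #include <vector>

    struct ToyGraph {
      std::map<std::string, std::vector<std::string>> siblings, consumers, producers;
    };

    void traverse(
        const ToyGraph& g, const std::string& reference,
        const std::function<bool(const std::string&, const std::string&)>& allow,
        const std::function<void(const std::string&, const std::string&, const char*)>&
            propagate) {
      std::set<std::string> visited{reference};
      std::queue<std::string> frontier;
      frontier.push(reference);
      auto tryHop = [&](const std::string& from, const std::string& to,
                        const char* kind) {
        if (visited.count(to) || !allow(from, to)) {
          return;
        }
        visited.insert(to);
        propagate(from, to, kind);
        frontier.push(to);
      };
      while (!frontier.empty()) {
        auto from = frontier.front();
        frontier.pop();
        if (g.siblings.count(from)) {
          for (const auto& to : g.siblings.at(from)) tryHop(from, to, "sibling");
        }
        if (g.consumers.count(from)) {
          for (const auto& to : g.consumers.at(from)) tryHop(from, to, "P2C");
        }
        if (g.producers.count(from)) {
          for (const auto& to : g.producers.at(from)) tryHop(from, to, "C2P");
        }
      }
    }

    int main() {
      // tv1 and tv2 are siblings (outputs of one multi-output expression);
      // tv3 consumes tv1, tv0 produces it.
      ToyGraph g;
      g.siblings["tv1"] = {"tv2"};
      g.consumers["tv1"] = {"tv3"};
      g.producers["tv1"] = {"tv0"};
      traverse(
          g, "tv1",
          [](const std::string&, const std::string&) { return true; },
          [](const std::string& from, const std::string& to, const char* kind) {
            std::cout << kind << ": " << from << " -> " << to << "\n";
          });
      return 0;
    }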
struct Selector { - virtual bool allowPasC(TensorView* from, TensorView* to) = 0; - virtual bool allowCasP(TensorView* from, TensorView* to) = 0; + virtual bool allowC2P(TensorView* from, TensorView* to) = 0; + virtual bool allowP2C(TensorView* from, TensorView* to) = 0; + virtual bool allowSibling(TensorView* from, TensorView* to) = 0; virtual ~Selector() {} }; // This is the interface to implement the actual propagation struct Propagator { - virtual void propagateTvPasC(TensorView* from, TensorView* to) = 0; - virtual void propagateTvCasP(TensorView* from, TensorView* to) = 0; + virtual void propagateC2P(TensorView* from, TensorView* to) = 0; + virtual void propagateP2C(TensorView* from, TensorView* to) = 0; + virtual void propagateSibling(TensorView* from, TensorView* to) = 0; virtual ~Propagator() {} }; @@ -73,6 +76,7 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree { private: enum class NextHopType { + SIBLING, C_AS_P, P_AS_C, }; @@ -122,6 +126,10 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree { TensorView* from, TensorView* to, std::shared_ptr from_info) const = 0; + virtual std::shared_ptr computeInfoSibling( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) const = 0; public: MaxInfoSpanningTree( @@ -201,6 +209,10 @@ class TORCH_CUDA_CU_API MaxRootDomainInfoSpanningTree TensorView* from, TensorView* to, std::shared_ptr from_info) const override; + virtual std::shared_ptr computeInfoSibling( + TensorView* from, + TensorView* to, + std::shared_ptr from_info) const override; private: static std::shared_ptr getReferenceRootIDInfo(TensorView* tv); @@ -231,6 +243,17 @@ class TORCH_CUDA_CU_API MaxRootDomainInfoSpanningTree selector) {} }; +class TORCH_CUDA_CU_API SpanningTreePrinter + : public MaxInfoSpanningTree::Propagator { + std::ostream& stream_; + + public: + virtual void propagateC2P(TensorView* from, TensorView* to) override; + virtual void propagateP2C(TensorView* from, TensorView* to) override; + virtual void propagateSibling(TensorView* from, TensorView* to) override; + SpanningTreePrinter(std::ostream& stream = std::cout) : stream_(stream) {} +}; + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/mma_type.cpp b/torch/csrc/jit/codegen/cuda/mma_type.cpp index 9ac067eb8b77d..60f3a0b95d4c6 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.cpp +++ b/torch/csrc/jit/codegen/cuda/mma_type.cpp @@ -40,6 +40,9 @@ MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) { // TODO: validate op config MmaOptions MmaBuilder::build() const { + TORCH_CHECK( + option_.mma_op != nullptr, + "Please configure accumulator tv before using swizzle options.") return option_; } @@ -53,6 +56,15 @@ void MmaBuilder::configureMma(TensorView* mma_output) const { mma->configureOptions(option_); } +void MmaBuilder::accumulatorTv(TensorView* tv) { + TORCH_CHECK( + tv->getMemoryType() == MemoryType::Local, "Mma only outputs to register"); + TORCH_CHECK(tv->definition(), "Input cannot be accumulator tv"); + auto mma = dynamic_cast(tv->definition()); + TORCH_CHECK(mma, "Requires mma op output for reduction tv"); + option_.mma_op = mma; +} + namespace { // Utility to get ldmatrix direction a mma layout and operand diff --git a/torch/csrc/jit/codegen/cuda/mma_type.h b/torch/csrc/jit/codegen/cuda/mma_type.h index 6b94d74a4f5b8..a004d1f1450c0 100644 --- a/torch/csrc/jit/codegen/cuda/mma_type.h +++ b/torch/csrc/jit/codegen/cuda/mma_type.h @@ -94,6 +94,9 @@ struct MmaOptions { operand == other.operand && accumulator_stride == 
other.accumulator_stride; } + + // To be inferred by mma builder interface. + MmaOp* mma_op = nullptr; }; //! User interface for configuring the mma and mma related @@ -127,6 +130,10 @@ class TORCH_CUDA_CU_API MmaBuilder { //! specified mma option. LoadStoreOpType ldMatrix() const; + //! Store the accumulator tv register reference in mma builder + //! to avoid automatic matching of which mma ops. + void accumulatorTv(TensorView* tv); + //! Fill in mma options in scheduling time. //! Each mma op in Fusion IR must be configured once before lowering. //! Mma options are configuration parameters used in lowering to mma diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index e828f2fa866df..928b20be64ca9 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -445,6 +445,26 @@ void OptOutMutator::mutate(Merge* m) { C10_UNUSED auto new_node = IrBuilder::create(container, ot, otr, in); } +void OptOutMutator::mutate(Swizzle2D* m) { + IterDomain* outx = maybeMutated(m->outX())->as(); + IterDomain* outy = maybeMutated(m->outY())->as(); + + IterDomain* inx = maybeMutated(m->inX())->as(); + IterDomain* iny = maybeMutated(m->inY())->as(); + + auto swizzle_type = m->swizzleType(); + + if (outx->sameAs(m->outX()) && outy->sameAs(m->outY()) && + inx->sameAs(m->inX()) && iny->sameAs(m->inY())) { + return; + } + auto container = m->container(); + container->removeExpr(m); + FusionGuard::getCurFusion()->removeExpr(m); + C10_UNUSED auto new_node = IrBuilder::create( + container, outx, outy, inx, iny, swizzle_type); +} + void OptOutMutator::mutate(kir::Allocate*) { TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } @@ -484,6 +504,15 @@ void OptOutMutator::mutate(kir::GridWelford*) { void OptOutMutator::mutate(kir::AllocateFusedReduction*) { TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } +void OptOutMutator::mutate(kir::Swizzle2DInt*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::PairSelect*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} +void OptOutMutator::mutate(kir::IntPair*) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); +} void OptOutMutator::removeExpr(IrContainer* container, Expr* expr) { container->removeExpr(expr); diff --git a/torch/csrc/jit/codegen/cuda/nvfuser.cmake b/torch/csrc/jit/codegen/cuda/nvfuser.cmake index 4f5d6a0a664e4..05c36a90d499d 100644 --- a/torch/csrc/jit/codegen/cuda/nvfuser.cmake +++ b/torch/csrc/jit/codegen/cuda/nvfuser.cmake @@ -29,6 +29,7 @@ list(APPEND NVFUSER_RUNTIME_FILES ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/warp.cu ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/tensorcore.cu ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/memory.cu + ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/swizzle.cu ${TORCH_ROOT}/aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh ${TORCH_ROOT}/aten/src/ATen/cuda/detail/UnpackRaw.cuh ) diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp index aba2a36230c81..ead0031d8358d 100644 --- a/torch/csrc/jit/codegen/cuda/parser.cpp +++ b/torch/csrc/jit/codegen/cuda/parser.cpp @@ -74,6 +74,32 @@ const auto& profileFailedAttr = Symbol::attr("profile_failed"); typedef Val* CgValue; typedef Expr* CgOp; +Val* castTensoToDtype(CgValue self, JitValue* cast_val) { + auto cast_ival = toIValue(cast_val); + // we need static type for cast + TORCH_INTERNAL_ASSERT(cast_ival.has_value()); + if (cast_ival->isInt()) { + auto dtype = 
cast_ival->toScalarType(); + + // We want to keep our internal fusion math in FP32 + // Shape Inference will continue to propagate the right + // type to outputs unchanged. + if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16) { + dtype = at::ScalarType::Float; + } + + return castOp(aten_to_data_type(dtype), self); + } else { + TORCH_INTERNAL_ASSERT( + cast_ival->isNone(), + "unrecognized dtype option, expect 'int' but got: ", + cast_ival->tagKind()); + + // return a copy if dtype is `None` + return set(self); + } +} + bool isReductionNonCompatibleTensor( const std::shared_ptr& tensor_type) { return is_zero_dim_tensor(tensor_type) || is_zero_sized_tensor(tensor_type); @@ -2706,10 +2732,9 @@ class IrParser { } } - // Limiting aten::to implementation to only change the dtype of a tensor { auto ptr_op = getOperatorForLiteral( - "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor"); + "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { @@ -2720,22 +2745,59 @@ class IrParser { auto self = list_val.front(); list_val.pop_front(); - // we need static type for cast - TORCH_INTERNAL_ASSERT( - node->input(1)->node()->kind() == prim::Constant); - auto dtype = toIValue(node->input(1))->toScalarType(); - - // We want to keep our internal fusion math in FP32 - // Shape Inference will continue to propagate the right - // type to outputs unchanged. - if (dtype == at::ScalarType::Half) { - dtype = at::ScalarType::Float; + auto out = castTensoToDtype(self, node->input(1)); + + value_map.emplace( + node->output()->unique(), ValueHolder(out, format)); + }, + [](const Node* node) -> bool { + if (!isInputNonSizeZeroTensor(node)) { + return false; + } + if (node->inputs()[1]->node()->kind() != prim::Constant) { + return false; + } + // we do not support explicit memory_format on output + if (!node->inputs()[2]->type()->isSubtypeOf( + static_cast(NoneType::get()))) { + return false; + } + // we do not support explicit memory_format on output + if (!node->inputs()[3]->type()->isSubtypeOf( + static_cast(NoneType::get()))) { + return false; + } + // we do not support explicit memory_format on output + if (!node->inputs()[4]->type()->isSubtypeOf( + static_cast(NoneType::get()))) { + return false; } - if (dtype == at::ScalarType::BFloat16) { - dtype = at::ScalarType::Float; + // we do not support explicit memory_format on output + if (!node->inputs()[6]->type()->isSubtypeOf( + static_cast(NoneType::get()))) { + return false; } + return true; + }, + nullptr); + } + + // Limiting aten::to implementation to only change the dtype of a tensor + { + auto ptr_op = getOperatorForLiteral( + "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? 
memory_format=None) -> Tensor"); + REGISTER_PARSE_RULE( + ptr_op, + { + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + c10::nullopt, value_map[node->inputs()[0]->unique()]); + auto self = list_val.front(); + list_val.pop_front(); + + auto out = castTensoToDtype(self, node->input(1)); - auto out = castOp(aten_to_data_type(dtype), self); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, @@ -4188,6 +4250,20 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { return true; } + static auto to_copy_schema = + getOperatorForLiteral( + "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor") + ->schema(); + if (node->matches(to_copy_schema)) { + switch (offset) { + case 1: + profileInt(pr, node, offset); + return true; + default: + return false; + } + } + static auto to_dtype_schema = getOperatorForLiteral( "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor") diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp index 965677154c5e2..2941b96fdae10 100644 --- a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp @@ -354,10 +354,8 @@ Bool* PredicateCompute::getInlinePredicate( ->as(); } - auto pred_info_vec = - Index::getReferenceRootPredicates( - out_tv, loops, nullptr, pred_type == PredicateType::Padding) - .first; + auto pred_info_vec = Index::getReferenceRootPredicates( + out_tv, loops, nullptr, pred_type == PredicateType::Padding); std::vector preds; @@ -466,7 +464,7 @@ void UnswitchPredicate::predicateOn(Expr* tv_expr) { // temporarily placed in the predicated_keys map and the final // predicates are generated in the finalize function. 
- for (const auto& pred_info : ref_pred_info.first) { + for (const auto& pred_info : ref_pred_info) { TORCH_INTERNAL_ASSERT(pred_info.startPredicate() != nullptr); TORCH_INTERNAL_ASSERT(pred_info.stopPredicate() != nullptr); diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py b/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py index aa2fb2016de8d..dc0390521d153 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py +++ b/torch/csrc/jit/codegen/cuda/python_frontend/examples/python_example_broadcast_in_dim.py @@ -1,6 +1,8 @@ import torch from torch._C._nvfuser import Fusion, FusionDefinition +import torch._prims as prims +import torch._refs as refs # Construct and Define Fusion fusion1 = Fusion() @@ -20,20 +22,25 @@ fusion1.print_ir() # Execute Fusion -input1 = torch.ones(3, device='cuda') -input2 = torch.ones(2, 3, 4, device='cuda') +input1 = torch.randn(3, device='cuda') +input2 = torch.randn(2, 3, 4, device='cuda') # Kernel compilation should be cached for the 2nd iteration # with input tensors of the same shape for _ in range(5) : - outputs = fusion1.execute([input1, input2]) + o = fusion1.execute([input1, input2])[0] -print(outputs[0]) +assert(o.shape == torch.Size([2, 3, 4])) + +# Reference in prim torch +ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2) +assert(ref_o.allclose(o)) +assert(ref_o.shape == o.shape) fusion2 = Fusion() -input1 = torch.ones(1, 1, 4, device='cuda') -input2 = torch.ones(2, 3, 4, device='cuda') +input1 = torch.randn(1, 1, 4, device='cuda') +input2 = torch.randn(2, 3, 4, device='cuda') with FusionDefinition(fusion2) as fd : t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride()) @@ -43,7 +50,6 @@ fd.add_input(t1) t0_b = fd.Ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2]) - print("Broadcast TensorView", t0_b) t2 = fd.Ops.add(t0_b, t1) fd.add_output(t2) @@ -53,6 +59,45 @@ # Kernel compilation should be cached for the 2nd iteration # with input tensors of the same shape for _ in range(5) : - outputs = fusion2.execute([input1, input2]) + o = fusion2.execute([input1, input2])[0] + +assert(o.shape == torch.Size([2, 3, 4])) + +# Reference in prim torch +ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [0, 1, 2]), input2) +assert(ref_o.allclose(o)) +assert(ref_o.shape == o.shape) + +# Construct and Define Fusion +fusion3 = Fusion() + +with FusionDefinition(fusion3) as fd : + # t0 = fd.define_tensor(2) + t0 = fd.define_tensor([3, 1], [1, 1]) + t1 = fd.define_tensor(1) + + fd.add_input(t0) + fd.add_input(t1) + + t1_b = fd.Ops.broadcast_in_dim(t1, [3, 3], [0]) # 1 -> 0 + t2 = fd.Ops.add(t0, t1_b) + + fd.add_output(t2) + +fusion3.print_ir() + +# Execute Fusion +input1 = torch.randn(3, 1, device='cuda') +input2 = torch.randn(3, device='cuda') + +# Kernel compilation should be cached for the 2nd iteration +# with input tensors of the same shape +for _ in range(5) : + o = fusion3.execute([input1, input2])[0] + +assert(o.shape == torch.Size([3, 3])) -print(outputs[0]) +# Reference in prim torch +ref_o = refs.add(input1, prims.broadcast_in_dim(input2, [3, 3], [0])) +assert(ref_o.allclose(o)) +assert(ref_o.shape == o.shape) diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp index fd7fc320e77a7..a702479bae69e 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp +++ 
b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp @@ -619,7 +619,8 @@ void initNvFuserPythonBindings(PyObject* module) { [](TensorView* input, std::vector& output_shape, std::vector& broadcast_dims) -> TensorView* { - const auto input_ndims = input->domain()->noReductions().size(); + const auto& iter_domains = input->domain()->noReductions(); + const auto input_ndims = iter_domains.size(); TORCH_CHECK( output_shape.size() >= input_ndims, "The new shape is expected to be greater-then-or-equal to the input", @@ -631,7 +632,9 @@ void initNvFuserPythonBindings(PyObject* module) { input_ndims, broadcast_dims.size()); + // default all dimensions to be broadcasted std::vector is_broadcast_dim(output_shape.size(), true); + std::vector is_expand_dim(output_shape.size(), true); for (const auto idx : c10::irange(broadcast_dims.size())) { if (idx > 0) { TORCH_CHECK( @@ -642,9 +645,30 @@ void initNvFuserPythonBindings(PyObject* module) { broadcast_dims[idx] < static_cast(output_shape.size()), "Invalid broadcast_dims value."); is_broadcast_dim.at(broadcast_dims[idx]) = false; + // Note: when we expand a broadcasted dimension, we need to expand it + // to a concrete size, hence the need for `is_expand_dim` flag and the + // expand operation following the broadcast. + is_expand_dim.at(broadcast_dims[idx]) = + iter_domains[idx]->isBroadcast(); } - return torch::jit::fuser::cuda::broadcast(input, is_broadcast_dim); + std::vector output_shape_on_bcast( + output_shape.size(), nullptr); + for (const auto idx : c10::irange(output_shape.size())) { + if (is_expand_dim[idx]) { + // TODO: this would be tricky to handle on dynamic shapes, we'll + // need to pass-in a symbol instead somehow. + output_shape_on_bcast[idx] = + IrBuilder::create(output_shape[idx]); + } else { + output_shape_on_bcast[idx] = IrBuilder::create(-1); + } + } + + auto bcasted_input = + torch::jit::fuser::cuda::broadcast(input, is_broadcast_dim); + return torch::jit::fuser::cuda::expand( + bcasted_input, output_shape_on_bcast); }, py::return_value_policy::reference); diff --git a/torch/csrc/jit/codegen/cuda/runtime/swizzle.cu b/torch/csrc/jit/codegen/cuda/runtime/swizzle.cu new file mode 100644 index 0000000000000..036f9eccd424b --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/runtime/swizzle.cu @@ -0,0 +1,72 @@ +// Utility macro for this file +#define DEVICE_INLINE __device__ inline + +// Utility class for 2D swizzle: +template +struct IndexGeneric { + const index_t x = 0, y = 0; + DEVICE_INLINE IndexGeneric(index_t x_, index_t y_) : x(x_), y(y_) {} +}; + +// Default type for integration +using Index2D = IndexGeneric; + +// Small type for unit computation +using Index2DInt = IndexGeneric; + +// ------------------------------------------------------------ +// Swizzle Definitions +// for each swizzle name: +// un(Swizzle Name) e.g. unZShape is the inverse of ZShape, +// (unswizzle is needed for inlining and is currently not actively used.) +// ------------------------------------------------------------ + +// Unit Z swizzle: +// Alternate directions of Y dimension: +// 1 2 3 1 2 3 +// 4 5 6 => 6 5 4 +// 7 8 9 7 8 9 +DEVICE_INLINE Index2D ZShape(Index2D in, Index2D unit_dim) { + return Index2D(in.x, in.x % 2 == 0 ? 
in.y : (unit_dim.y - in.y - 1)); +} + +// ZShape is inverse of itself +DEVICE_INLINE Index2D unZShape(Index2D in, Index2D unit_dim) { + return ZShape(in, unit_dim); +} + +// Block cyclic Xor swizzle: (bank conflict removal) +// Apply cyclic Xor within blocks: +// Example: cyclic Xor +// 1 2 3 4 1 2 3 4 +// 5 6 7 8 6 5 8 7 +// 9 10 11 12 => 11 12 9 10 +// 13 14 15 16 16 15 14 13 +// Note: +DEVICE_INLINE Index2D Xor(Index2D in, Index2DInt unit_dim) { + // Need to validate in swizzle configuration: + // unit_dim.x == unit_dim.y + return Index2D(in.x, (in.y ^ in.x)); +} + +// Inverse of Xor is itself +DEVICE_INLINE Index2D unXor(Index2D in, Index2DInt unit_dim) { + return Xor(in, unit_dim); +} + +// Scatter swizzle: +// Corresponds to the data layout out of ldmatrix intrinsic. +// supported dimensions are : 8x4, 16x4, 32x4 +template +DEVICE_INLINE Index2D Scatter(Index2D in) { + static_assert(row_size == 8 || row_size == 16 || row_size == 32); + return Index2D((in.y * row_size + in.x) / 4, in.x % 4); +} + +template +DEVICE_INLINE Index2D unScatter(Index2D in) { + static_assert(row_size == 8 || row_size == 16 || row_size == 32); + return Index2D(in.y + (in.x % (row_size / 4)) * 4, in.x / (row_size / 4)); +} + +#undef DEVICE_INLINE diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index 4eb8ff69e0484..35190ea6f0908 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include namespace torch { @@ -16,100 +17,6 @@ namespace { // Utility for mma dimension matching enum class MmaDimension { M = 0, N, K }; -// Utility for mma dimension matching before broadcast, -// assumes the innermost 2 dimensions are the mma -// operand dimensions, i.e. mnk. -IterDomain* getMmaOperandRootDimension2d( - TensorView* tv, - MmaOptions options, - MmaDimension mma_dimension) { - TORCH_INTERNAL_ASSERT(tv->getMaybeRFactorDomain().size() >= 2); - // NT : K,M x K,N -> K,M,N - // TT : M,K X K,N -> M,K,N - // TN : M,K X N,K -> M,N,K - int axis_id = mma_dimension == MmaDimension::K ? 1 : 0; - bool is_transposed = isOperandTransposed(options); - - // Decode the transpostion - if ((options.operand == MmaOptions::Operand::A && !is_transposed) || - (options.operand == MmaOptions::Operand::B && is_transposed)) { - axis_id = 1 - axis_id; - } - - int root_size = tv->getMaybeRFactorDomain().size(); - // Convert to index from right. - return tv->getMaybeRFactorDomain().at(root_size + axis_id - 2); -} - -// Utility for mma dimension matching, assumes the innermost -// 3 dimensions are the mma operand dimensions, i.e. mnk, but -// not necessarily in this order. -// For matmul use cases the root domains are always 3 dimensional, -// but this wouldn't be the case for other kernels such as batched gemm. -// This utility only applies to the case where the innermost 3 dims -// are the one that mma's are used. We probably don't want to use -// mma intrinsics if that's not the case. -IterDomain* getMmaOperandRootDimension3d( - TensorView* tv, - MmaOptions::MmaInputLayout layout, - MmaDimension mma_dimension) { - TORCH_INTERNAL_ASSERT(tv->getMaybeRFactorDomain().size() >= 3); - // NT : K,M x K,N -> K,M,N - // TT : M,K X K,N -> M,K,N - // TN : M,K X N,K -> M,N,K - int axis_id = -1; - switch (mma_dimension) { - case MmaDimension::K: - axis_id = (int)layout; - break; - case MmaDimension::M: - axis_id = layout == MmaOptions::MmaInputLayout::NT ? 
1 : 0; - break; - case MmaDimension::N: - axis_id = layout == MmaOptions::MmaInputLayout::TN ? 1 : 2; - break; - default: - TORCH_INTERNAL_ASSERT(false, "Unreachable"); - break; - } - - int root_size = tv->getMaybeRFactorDomain().size(); - // Convert to index from right. - return tv->getMaybeRFactorDomain().at(root_size + axis_id - 3); -} - -// Locate the root id corresponding to the given mma dimension -// Assumes the mma dimension always the innermost 2 or 3, might -// need to extend for more complex fusions. -IterDomain* getMmaOperandRootDimension( - TensorView* tv, - MmaOptions options, - MmaDimension mma_dimension) { - if (isVolta(options.macro)) { - return getMmaOperandRootDimension3d( - tv, options.operand_layout, mma_dimension); - } else if (isTuring(options.macro) || isAmpere(options.macro)) { - // Volta mma swizzle requires the broadcast dimension to - // participate, which is not true in Turing+. So the two - // cases w/ or w/o the broadcast are supported here for - // mma pre-swizzle validation. - bool has_broadcast_or_reduction = std::any_of( - tv->getMaybeRFactorDomain().begin(), - tv->getMaybeRFactorDomain().end(), - [](IterDomain* id) { return id->isBroadcast() || id->isReduction(); }); - if (has_broadcast_or_reduction) { - TORCH_INTERNAL_ASSERT(tv->nDims() >= 3); - return getMmaOperandRootDimension3d( - tv, options.operand_layout, mma_dimension); - } else { - TORCH_INTERNAL_ASSERT(tv->nDims() >= 2); - return getMmaOperandRootDimension2d(tv, options, mma_dimension); - } - } - TORCH_INTERNAL_ASSERT(false, "unreachable"); - return nullptr; -} - // Preliminary checks to try to validate that leaf is // a innermost dim of root of exactly the given size. bool canValidateIsInnerDim( @@ -259,39 +166,208 @@ void WarpMmaSwizzler::setWarpMapped(TensorView* tv, int number_of_dims) { namespace { -// Utility to check operand innermost scheduled dimensions -void validateInnerMNK(TensorView* tv, MmaOptions options, int m, int n, int k) { +// Utility function for mma domain mapping: +// returns the Iterdomain from the accumulator tv that corresponds +// to the given mma dimension. See [MMA dimension matching]. +std::vector getMmaDomains(MmaOp* mma, MmaDimension dimension) { + // This utility is user facing so shouldn't ever see tensor index here. + + // Note: [Use Root Domain in Accumulator TV] + // Have to use root domain for accumulator tv since the operands do not have + // root/rfactor domains that map to the rfactor domain of output. + // For example: + // C[I,I,R,R] = mma (A[I,B,I,I], B[B,I,I,I]), + // if we do + // c->split(-1,4); + // c->rfactor(-1); + // on the mma stage we get: + // C[I,I,R,Io,R(4)] = mma (A[I,B,I,I], B[B,I,I,I]), + // and in this case Io and R(4) would not be able to find root mapping + // in A or B. + // + // Essentially in the case of rfactor, this utility does producer side + // matching so looking at root domain would be required. + // This matching pattern should support most common matmul applications, + // but in follow ups we may need to extend RFactor matching if there + // are more complex scheduling patterns that we want to support. 
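The classification used below can be summarized with a small standalone sketch (not part of the patch; per-axis boolean flags stand in for the IterDomain queries): an axis is K when it is concrete in both operands and reduced in the output, M when it is concrete only in A, and N when it is concrete only in B.

    // Standalone sketch: classify root axes into M, N, or K from
    // broadcast/reduction flags, mirroring the switch in getMmaDomains.
    #include <iostream>
    #include <string>
    #include <vector>

    struct RootAxis {
      std::string name;
      bool broadcast_in_a;
      bool broadcast_in_b;
      bool reduction_in_out;
    };

    std::string classify(const RootAxis& a) {
      if (!a.broadcast_in_a && !a.broadcast_in_b && a.reduction_in_out) {
        return "K";
      }
      if (!a.broadcast_in_a && a.broadcast_in_b && !a.reduction_in_out) {
        return "M";
      }
      if (a.broadcast_in_a && !a.broadcast_in_b && !a.reduction_in_out) {
        return "N";
      }
      return "unclassified";
    }

    int main() {
      // C[M, N, rK] = mma(A[M, b, K], B[b, N, K]) expressed as per-axis flags.
      std::vector<RootAxis> axes{
          {"I0", false, true, false}, // M
          {"I1", true, false, false}, // N
          {"I2", false, false, true}, // K
      };
      for (const auto& a : axes) {
        std::cout << a.name << " -> " << classify(a) << "\n";
      }
      return 0;
    }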
+ auto accumulator_domain = mma->out()->as()->getRootDomain(); + auto a_domain = TensorDomain::noReductions( + mma->inA()->as()->getMaybeRFactorDomain()); + auto b_domain = TensorDomain::noReductions( + mma->inB()->as()->getMaybeRFactorDomain()); + TORCH_CHECK( + a_domain.size() == b_domain.size() && + a_domain.size() == accumulator_domain.size(), + "Inconsistent dimensions in mma op", + a_domain.size(), + " ", + b_domain.size(), + " ", + accumulator_domain.size()); + + std::vector result; + + for (int id_idx : c10::irange(a_domain.size())) { + // Checks if this id should be included in the result + bool include_this_id = false; + bool is_broadcast_in_a = a_domain[id_idx]->isBroadcast(); + bool is_broadcast_in_b = b_domain[id_idx]->isBroadcast(); + bool is_reduction_id = accumulator_domain[id_idx]->isReduction(); + + switch (dimension) { + case MmaDimension::K: + // K dimension is the dimension that is concrete in + // operands, and is reduced by mma. This complies with + // tensor contraction definition. + include_this_id = + !is_broadcast_in_a && !is_broadcast_in_b && is_reduction_id; + break; + // The M and N dimensions below are defined as the iterdomains + // that are not reduced by mma, and are concretized in this stage. + case MmaDimension::M: + include_this_id = + !is_broadcast_in_a && is_broadcast_in_b && !is_reduction_id; + break; + case MmaDimension::N: + include_this_id = + is_broadcast_in_a && !is_broadcast_in_b && !is_reduction_id; + break; + + default: + TORCH_INTERNAL_ASSERT(false, "unreachable"); + } + + if (include_this_id) { + result.push_back(accumulator_domain.at(id_idx)); + } + } + + return result; +} + +// [MMA dimension matching] +// Returns all the axes that correspond to the given mma dimension. This is the +// first relaxation step on the mma check. +// Mma operations concern 3 dimensions, namely the M, N, +// and K dimensions; for more details see [Operand Layout Convention] in mma_type.h. +// The current implementation, for best effort safety, supports the patterns +// where the root axes can be classified into one of the 3 dimension types. +// This is a helpful initial step into defining tensor contraction +// optimizations. +// +// A concrete example: +// T0 [I0, I1, I2, I3, I4, I5] = mma(T1[I01, B11, B21, I31, I41, B51], T2[B02, +// I12, B22, I32, I42, I52], {3}); +// In this case some example queries: +// K dimension of T0 = {I3} +// M dimension of T1 = {I01} +// N dimension of T2 = {I52} +// etc. +std::vector getMmaRootDimensions( + TensorView* tv, + MmaOp* mma, + MmaDimension dimension) { + // Build a fusion-level root domain map + // so we can use the mma swizzles on non-immediate tensor operands, for + // example loadstore staging ops. + ComputeAtRootDomainMap root_map; + root_map.build(); + + // FIXME: + // Several optimizations are possible at this stage, but assuming we don't have + // a lot of mma ops in a fusion this could be lower priority. + // First, it'd be nice not to have to build the root map every time this function + // is called. That'd require some explicit boundary where we "lock" the + // compute in the fusion so the root map stays valid. + // Second, it'd reduce the complexity of the matching below by an order if we have + // something similar to "disjointSetOf" in idGraph, for just the root domains + // at scheduler composing time. + auto mma_root_dimensions = getMmaDomains(mma, dimension); + auto mma_accumulator_tv = mma->out()->as(); + + std::vector result; + + // Need to use root domain for accumulator tv and maybe rfactor domain + // otherwise.
See [Use Root Domain in Accumulator TV]. + auto is_mma_output = + tv->definition() != nullptr && tv->definition()->isA(); + const auto& tv_root_domain = + is_mma_output ? tv->getRootDomain() : tv->getMaybeRFactorDomain(); + + // Loop through tensorview's root domains and accumulate all the + // root domain IterDomain's that maps to any of the collected + // mma root dimension from the mma accumulator tv. + for (auto tv_id : tv_root_domain) { + if (std::any_of( + mma_root_dimensions.begin(), + mma_root_dimensions.end(), + [&](IterDomain* mma_id) { + return root_map.canMap( + tv->domain(), tv_id, mma_accumulator_tv->domain(), mma_id); + })) { + result.push_back(tv_id); + } + } + + return result; +} + +//! Utility function to help check that the innermost 3 iterdomains +//! are also the corresponding innermost {m,n,k} dimensions of +//! the root id's that are participating in the mma operation. +//! This is a format check before the warp mma swizzler applies mma +//! swizzles to make sure that the swizzler is applying the right +//! swizzles to the right axes. +//! This check will be relaxed as we build out the mma usage patterns. +void validateMmaRootInnerMNK( + TensorView* tv, + MmaOptions options, + int m, + int n, + int k) { + auto m_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::M); + auto n_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::N); + auto k_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::K); + + TORCH_CHECK( + !m_dims.empty() && !n_dims.empty() && !k_dims.empty(), + "validateMmaRootInnerMNK: MMA Axes incomplete"); + + // Still check the innermost dims of each at the current state: TORCH_INTERNAL_ASSERT(tv->nDims() >= 3); TORCH_INTERNAL_ASSERT( - canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::M), - tv->axis(-3), - m), + canValidateIsInnerDim(m_dims.back(), tv->axis(-3), m), "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); TORCH_INTERNAL_ASSERT( - canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::N), - tv->axis(-2), - n), + canValidateIsInnerDim(n_dims.back(), tv->axis(-2), n), "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); TORCH_INTERNAL_ASSERT( - canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::K), - tv->axis(-1), - k), + canValidateIsInnerDim(k_dims.back(), tv->axis(-1), k), "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); } -void validateResultInnerMN(TensorView* tv, int m, int n) { +//! Utility function to help check that the innermost 3 iterdomains +//! are also the corresponding innermost {m,n} dimensions of +//! the root id's that are participating in the mma operation. +//! This is a format check before the warp mma swizzler applies mma +//! swizzles to make sure that the swizzler is applying the right +//! swizzles to the right axes. +//! This check will be relaxed as we build out the mma usage patterns. 
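Both validation helpers reduce to a structural question about the innermost leaf axes: does each leaf derive from the expected root M/N/K axis and carry exactly the instruction-tile extent? A standalone sketch of that check under simplified assumptions (toy Leaf type, divisible splits only; none of this is nvfuser API):

#include <cassert>

// Toy model of a divisible inner split: a root axis split by `factor` yields
// an innermost leaf of extent `factor`. Not nvfuser API.
struct Leaf {
  int root_id;             // which root axis this leaf derives from
  int extent;              // extent of the leaf
  bool innermost_of_root;  // true if it is the innermost piece of that root
};

Leaf innerSplit(int root_id, int root_extent, int factor) {
  assert(root_extent % factor == 0);  // simplifying assumption
  return {root_id, factor, true};
}

// Mirrors the intent of canValidateIsInnerDim: the leaf must come from the
// given root, be its innermost piece, and have exactly the requested extent.
bool isInnerDimOf(const Leaf& leaf, int root_id, int expected_extent) {
  return leaf.root_id == root_id && leaf.innermost_of_root &&
      leaf.extent == expected_extent;
}

int main() {
  // E.g. for a 16-wide instruction tile on M: the leaf placed at position -3
  // should be a 16-wide innermost piece of the root M axis.
  Leaf m_leaf = innerSplit(/*root_id=*/0, /*root_extent=*/64, /*factor=*/16);
  assert(isInnerDimOf(m_leaf, 0, 16));
  assert(!isInnerDimOf(m_leaf, 0, 8));
  return 0;
}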
+void validateMmaRootInnerMN(TensorView* tv, MmaOptions options, int m, int n) { + auto m_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::M); + auto n_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::N); + + TORCH_CHECK( + !m_dims.empty() && !n_dims.empty(), + "validateMmaRootInnerMNK: MMA Axes incomplete"); + + // Still check the innermost dims of each at the current state: TORCH_INTERNAL_ASSERT(tv->nDims() >= 2); - int root_dim = tv->getMaybeRFactorDomain().size(); TORCH_INTERNAL_ASSERT( - canValidateIsInnerDim( - tv->getMaybeRFactorDomain()[root_dim - 2], tv->axis(-2), m), + canValidateIsInnerDim(m_dims.back(), tv->axis(-2), m), "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); TORCH_INTERNAL_ASSERT( - canValidateIsInnerDim( - tv->getMaybeRFactorDomain()[root_dim - 1], tv->axis(-1), n), + canValidateIsInnerDim(n_dims.back(), tv->axis(-1), n), "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); } @@ -341,7 +417,7 @@ void scheduleVoltaA(TensorView* tv, MmaOptions options) { // [..., 16, 16 ,4] // [..., M, BN, K] // Some validation: - validateInnerMNK(tv, options, 16, 16, 4); + validateMmaRootInnerMNK(tv, options, 16, 16, 4); bool transposed = isOperandTransposed(options); tv->split(-3, 4); @@ -377,7 +453,7 @@ void scheduleVoltaB(TensorView* tv, MmaOptions options) { // [..., 16,16,4] // [..., BM, N, K] // Some validation: - validateInnerMNK(tv, options, 16, 16, 4); + validateMmaRootInnerMNK(tv, options, 16, 16, 4); bool transposed = isOperandTransposed(options); tv->split(-3, 16); @@ -424,18 +500,16 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) { if (options.operand == MmaOptions::Operand::A) { TORCH_INTERNAL_ASSERT(tv->nDims() >= 2); // validation: + auto m_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::M); + auto k_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::K); + TORCH_INTERNAL_ASSERT( - canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::M), - tv->axis(-2), - 16), + canValidateIsInnerDim(m_dims.back(), tv->axis(-2), 16), "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); TORCH_INTERNAL_ASSERT( - canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::K), - tv->axis(-1), - 16), - "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); + canValidateIsInnerDim(k_dims.back(), tv->axis(-1), 16), + "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain", + tv->toString()); //[16m, 16k] tv->split(-2, 8); @@ -458,18 +532,15 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) { tv->axis(-2)->parallelize(ParallelType::TIDx); } else if (options.operand == MmaOptions::Operand::B) { + auto n_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::N); + auto k_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::K); + // validation: TORCH_INTERNAL_ASSERT( - canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::N), - tv->axis(-2), - 8), + canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 8), "MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain"); TORCH_INTERNAL_ASSERT( - canValidateIsInnerDim( - getMmaOperandRootDimension(tv, options, MmaDimension::K), - tv->axis(-1), - 16), + canValidateIsInnerDim(k_dims.back(), tv->axis(-1), 16), "MMA swizzle: requires instruction tile 
iterdomains on the innermost side of the tensordomain"); if (transposed) { @@ -539,9 +610,9 @@ void WarpMmaSwizzler::scheduleVoltaM16N16K4Fp32Output( // Make sure instruction tile size is correct. if (is_reduction) { - validateInnerMNK(tv, options, 16, 16, 4); + validateMmaRootInnerMNK(tv, options, 16, 16, 4); } else { - validateResultInnerMN(tv, 16, 16); + validateMmaRootInnerMN(tv, options, 16, 16); } int m_pos = is_reduction ? -3 : -2; @@ -603,9 +674,9 @@ void WarpMmaSwizzler::scheduleTuringM16N8K16MmaWarpOutput( // Make sure instruction tile size is correct. if (is_reduction) { - validateInnerMNK(tv, options, 16, 8, 16); + validateMmaRootInnerMNK(tv, options, 16, 8, 16); } else { - validateResultInnerMN(tv, 16, 8); + validateMmaRootInnerMN(tv, options, 16, 8); } int m_pos = is_reduction ? -3 : -2; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp index 6f3b736579d93..656db1c0ed805 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp @@ -983,9 +983,8 @@ TORCH_CUDA_CU_API void schedulePersistentKernel( // Cache inputs if unrolled auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll); - // Cache and fork outputs - std::vector> cached_outputs = - scheduler_utils::cacheAndForkOutputs(fusion, unroll); + // Cache and fork outputs + auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll); // Make sure we don't have global memory set on intermediate tensors from // fusion segmentation diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 75172a4357331..454f9433d6dfa 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -13,6 +14,9 @@ #include +#include +#include + namespace torch { namespace jit { namespace fuser { @@ -592,15 +596,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { ir_utils::getReductionOps(fusion /*, ignore_trivial=true */).empty(), "This scheduler only handles pointwise ops."); - // For intermediate outputs, apply cacheFork - auto outs = fusion->outputs(); - for (const auto output : outs) { - if (!output->uses().empty() && output->definition() != nullptr) { - if (output->getValType().value() == ValType::TensorView) { - output->as()->cacheFork(); - } - } - } + // Cache inputs + auto cached_inputs = scheduler_utils::cacheInputs(fusion, true); + + // Cache and fork outputs + auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true); std::vector input_tvs; { @@ -637,47 +637,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { reference_tv != nullptr, "Could not find a fully broadcasted output to reference schedule on."); - IterDomain* inner_most_id = nullptr; - for (auto it = reference_tv->domain()->domain().rbegin(); - it != reference_tv->domain()->domain().rend(); - it++) { - if ((*it)->isReduction()) { - continue; - } - if ((*it)->isBroadcast() && inner_most_id == nullptr) { - inner_most_id = *it; - } - inner_most_id = *it; - break; - } - - TORCH_INTERNAL_ASSERT(inner_most_id != nullptr); - - // Caches of inputs - std::vector cached_inputs; - - // Output, cacheBefore of output - std::vector> cached_outputs; - - // Track what should be vectorized versus unrolled - std::unordered_set vectorized_tensor; - - // Figure out 
which inputs to cache for unrolling or vectorization - for (auto inp : input_tvs) { - if (inp->uses().empty() || inp->isFusionOutput()) { - continue; - } - cached_inputs.emplace_back(inp->cacheAfter()); - } - - // Figure out which outputs to cache for unrolling or vectorization - for (auto out : output_tvs) { - if (out->definition() == nullptr) { - continue; - } - cached_outputs.emplace_back(std::make_pair(out, out->cacheBefore())); - } - auto all_tvs = ir_utils::allTvs(fusion); // Merge right side of break point @@ -716,6 +675,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } } + int64_t unswitch_pos; if (params.break_point) { // 2D parallelization scheme TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0); @@ -768,8 +728,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx] reference_tv->split(0, 65535); reference_tv->axis(1)->parallelize(ParallelType::BIDy); + unswitch_pos = 5; } else { reference_tv->axis(0)->parallelize(ParallelType::BIDy); + unswitch_pos = 4; } } else { // [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx] @@ -779,8 +741,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx] reference_tv->split(1, 65535); reference_tv->axis(2)->parallelize(ParallelType::BIDy); + unswitch_pos = 5; } else { reference_tv->axis(1)->parallelize(ParallelType::BIDy); + unswitch_pos = 4; } } } else { @@ -792,8 +756,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // [i-remainder, BIDy{65535} | BIDx | Unswitch, Unroll, TIDx] reference_tv->split(0, 65535); reference_tv->axis(1)->parallelize(ParallelType::BIDy); + unswitch_pos = 4; } else { reference_tv->axis(0)->parallelize(ParallelType::BIDy); + unswitch_pos = 3; } } else { // [BIDx | BIDy | Unswitch, Unroll, TIDx] @@ -802,8 +768,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // [BIDx | i-remainder, BIDy{65535} | Unswitch, Unroll, TIDx] reference_tv->split(1, 65535); reference_tv->axis(2)->parallelize(ParallelType::BIDy); + unswitch_pos = 4; } else { reference_tv->axis(1)->parallelize(ParallelType::BIDy); + unswitch_pos = 3; } } } @@ -848,10 +816,12 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { reference_tv->axis(1)->parallelize(ParallelType::Unswitch); reference_tv->axis(3)->parallelize(ParallelType::TIDx); } + unswitch_pos = 2; } TransformPropagator propagator(reference_tv); - MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator); + MaxRootDomainInfoSpanningTree spanning_tree(reference_tv); + spanning_tree.traverse(&propagator); scheduler_utils::parallelizeAllLike(reference_tv, all_tvs); if (params.vectorize) { @@ -886,84 +856,31 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { } } - // Compute at into cached inputs - std::vector consumers_of_cached_inputs; - // Cache of input, and one of its consumers - std::vector> input_cache_and_consumer; - { - // Avoid duplicate additions, so track what we add - std::unordered_set added; - for (auto cached_input : cached_inputs) { - auto consumer_tvs = ir_utils::consumerTvsOf(cached_input); - TORCH_INTERNAL_ASSERT( - consumer_tvs.size(), - "Input was not succesfully filtered out for scheduling but wasn't used."); - - // Grab a consumer which will be used for computeAt structure of cached - // input into a consumer - input_cache_and_consumer.emplace_back( - 
std::make_pair(cached_input, consumer_tvs[0])); - - // Grab all consumers which will be used for inlining computeAt for the - // body of the computation (excluding caching inputs/outputs) - for (auto consumer_tv : consumer_tvs) { - // Don't duplicate - if (added.insert(consumer_tv).second) { - consumers_of_cached_inputs.emplace_back(consumer_tv); - } - } - } - } - - for (auto entry : input_cache_and_consumer) { - // Compute at inside unswitch position: - auto input_cache = entry.first; - auto input_cache_consumer = entry.second; - - auto unswitch_it = std::find_if( - input_cache_consumer->domain()->domain().begin(), - input_cache_consumer->domain()->domain().end(), - [](IterDomain* id) { - return id->getParallelType() == ParallelType::Unswitch; - }); - auto unswitch_pos = - unswitch_it == input_cache_consumer->domain()->domain().end() - ? -1 - : std::distance( - input_cache_consumer->domain()->domain().begin(), unswitch_it) + - 1; + // Begin by inlining at the unswitch position for the entire DAG. The cached + // inputs, and outputs will keep this inline position, but other tensors will + // get a higher position in later inline propagation. + InlinePropagator inline_unswitch( + reference_tv, unswitch_pos, ComputeAtMode::BestEffort); + spanning_tree.traverse(&inline_unswitch); - input_cache->computeAt( - input_cache_consumer, unswitch_pos, ComputeAtMode::BestEffort); + // Inline at the inner most position. The CA position of all tensors except + // inputs, cached inputs and outputs will be updated. + std::unordered_set inner_most_tensors( + all_tvs.begin(), all_tvs.end()); + for (auto cached_input : cached_inputs) { + inner_most_tensors.erase(cached_input); } - - // Producers for inlined computeAt - std::vector compute_from = consumers_of_cached_inputs; - - // Consumers for inlined computeAt - std::vector compute_to; - // Compute at cached outputs - //[BIDx, Unswitch, Vectorization, TIDx] for (auto entry : cached_outputs) { - auto cached_output = entry.second; - auto output = entry.first; - - auto unswitch_it = std::find_if( - output->domain()->domain().begin(), - output->domain()->domain().end(), - [](IterDomain* id) { - return id->getParallelType() == ParallelType::Unswitch; - }); - auto unswitch_pos = unswitch_it == output->domain()->domain().end() - ? 
-1 - : std::distance(output->domain()->domain().begin(), unswitch_it) + 1; - - cached_output->computeAt(output, unswitch_pos, ComputeAtMode::BestEffort); - compute_to.push_back(cached_output); + auto output = entry.second; + inner_most_tensors.erase(output); } + InlinePropagator inline_inner_most( + reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors); + spanning_tree.traverse(&inline_inner_most); - scheduler_utils::computeAtBetween( - compute_from, compute_to, -1, ComputeAtMode::BestEffort); + // Fix max producer position + MaxProducerPosUpdater updater; + spanning_tree.traverse(&updater); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp index 1696242f4ff28..84a78bcf927d1 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp @@ -1002,9 +1002,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) { // Cache inputs if unrolled auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll); - // Cache and fork outputs - std::vector> cached_outputs = - scheduler_utils::cacheAndForkOutputs(fusion, unroll); + // Cache and fork outputs + auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll); // Make sure we don't have global memory set on intermediate tensors from // fusion segmentation diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index 56310f226fde6..20869353c201f 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -1499,32 +1499,64 @@ void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) { auto instruction_tile = tile.instruction_tile; TORCH_CHECK( - warp_tile.k == cta_tile.k, - "schedule warp tile: currently no support for splitting k dimension to different warps"); + cta_tile.k % warp_tile.k == 0, + "Number of warp on k dimension need to be integer"); + + int num_warp_k = cta_tile.k / warp_tile.k; mma_util::checkDimSize( tv, {-3, -2, -1}, {cta_tile.m, cta_tile.n, cta_tile.k}); - // -3 -2 -1 - //[... M, N, K] - - // Distribute warp tile: - tv->split(-3, warp_tile.m); - tv->split(-2, warp_tile.n); + if (num_warp_k == 1) { + // Non split K over warp case: - // -5 -4 -3 -2 -1 - // [Mwo Mw Nwo Nw K] - tv->split(-4, instruction_tile.m); - tv->split(-2, instruction_tile.n); - tv->split(-1, instruction_tile.k); + // -3 -2 -1 + //[... M, N, K] + // Distribute warp tile: + tv->split(-3, warp_tile.m); + tv->split(-2, warp_tile.n); - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mwo Mw Mi Nwo Nw Ni Ko Ki] + // -5 -4 -3 -2 -1 + // [Mwo Mw Nwo Nw K] + tv->split(-4, instruction_tile.m); + tv->split(-2, instruction_tile.n); + tv->split(-1, instruction_tile.k); - tv->reorder({{-7, -5}, {-6, -3}, {-5, -7}, {-3, -2}, {-2, -6}}); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Mw Mi Nwo Nw Ni Ko Ki] - // -8 -7 -6 -5 -4 -3 -2 -1 - // [Mwo Nwo Ko Mw Nw Mi Ni Ki] + tv->reorder({{-7, -5}, {-6, -3}, {-5, -7}, {-3, -2}, {-2, -6}}); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Nwo Ko Mw Nw Mi Ni Ki] + } else { + // Split K over warp case: + // Main difference is that an additional + // thread dimension needs to be reserved + // for cross warp reduction: + // -3 -2 -1 + //[... 
M, N, K] + // Distribute warp tile: + tv->split(-3, warp_tile.m); + tv->split(-2, warp_tile.n); + tv->split(-1, warp_tile.k); + + // -6 -5 -4 -3 -2 -1 + // [Mwo Mw Nwo Nw K, Kw] + tv->split(-5, instruction_tile.m); + tv->split(-3, instruction_tile.n); + tv->split(-1, instruction_tile.k); + + // -9 -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Mw Mi Nwo Nw Ni Kwo Kw Ki] + + tv->reorder({{-8, -6}, {-7, -3}, {-6, -8}, {-4, -2}, {-3, -7}, {-2, -4}}); + // -9 -8 -7 -6 -5 -4 -3 -2 -1 + // [Mwo Nwo Ko Mw Nw Kw, Mi Ni Ki] + + tv->merge(-9); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [MNwo Ko Mw Nw Kw, Mi Ni Ki] + } } void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) { @@ -1536,6 +1568,12 @@ void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) { mma_util::checkDimSize(tv, {-2, -1}, {cta_tile.m, cta_tile.n}); + TORCH_CHECK( + cta_tile.k % warp_tile.k == 0, + "Number of warp on k dimension need to be integer"); + + int num_warp_k = cta_tile.k / warp_tile.k; + // -2 -1 //[... M, N] @@ -1555,6 +1593,14 @@ void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) { // -6 -5 -4 -3 -2 -1 // [Mwo Nwo Mw Nw Mi Ni] + + if (num_warp_k != 1) { + // The non reduction warps are merged together + // to save one thread dim for cross dim reduce. + tv->merge(-6); + // -5 -4 -3 -2 -1 + // [MNo Mw Nw Mi Ni] + } } //! Split the innermost dim to a vectorized load @@ -1568,9 +1614,21 @@ void scheduleContiguousVectorLoad( tv->split(-1, num_of_thread * vector_word); tv->split(-1, vector_word); // [..., thread, vec] - // distribute to warp: + // distribute to warp: for tidx tv->split(-2, 32); - tv->split(-3, warp_dims.n * warp_dims.k); + + // -3 -2 -1 + // [...warp, lane, vec] + + if (warp_dims.k == 1) { + // -4 -3 -2 -1 + // [...warpM, warpN, lane, vec] + tv->split(-3, warp_dims.n); + } else { + // -4 -3 -2 -1 + // [...warpMN, warpR, lane, vec] + tv->split(-3, warp_dims.k); + } tv->axis(-1)->parallelize(ParallelType::Vectorize); tv->axis(-2)->parallelize(ParallelType::TIDx); diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 71983b1f162c9..5e1ecf1a940a4 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -210,7 +211,8 @@ TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner) memory_type_(src->memory_type_), swizzle_type_(src->swizzle_type_), is_double_buffered_(src->is_double_buffered_), - cpu_scalar_(src->cpu_scalar_) { + cpu_scalar_(src->cpu_scalar_), + has_swizzle_op_(src->has_swizzle_op_) { for (const auto id : src->axesToSwizzle()) { axes_to_swizzle_.push_back(ir_cloner->clone(id)); } @@ -569,6 +571,85 @@ TensorView* TensorView::swizzle( return this; } +TensorView* TensorView::swizzle(Swizzle2DType swizzle_type, int x, int y) { + has_swizzle_op_ = true; + if (x < 0) { + x += domain()->nDims(); + } + if (y < 0) { + y += domain()->nDims(); + } + + TORCH_CHECK( + x >= (int)getComputeAtPosition(), + false, + "Cannot swizzle axes within compute at position. Axis ", + x, + " is within computeAtPosition = ", + getComputeAtPosition()); + + TORCH_CHECK( + y >= (int)getMaxProducerPosition(), + "Cannot swizzle axes within max producer position. Axis ", + y, + " is within maxProducerPosition = ", + getMaxProducerPosition()); + + // Disable unsupported use cases at the current step. + // Currently do not support reducing or broadcasting + // swizzled dimensions. 
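The swizzle types checked here correspond to the index transforms added to runtime/swizzle.cu earlier in this diff. ZShape and XOR are involutions on a tile, which is why unZShape/unXor are implemented as plain re-application; below is a standalone check of that property, assuming the zigzag branch applies to odd rows (illustrative code only, not the runtime header itself):

#include <cassert>

struct Idx2 {
  int x;
  int y;
};

// Zigzag ("ZShape") within a unit tile: alternate rows traverse y in reverse.
Idx2 zshape(Idx2 in, Idx2 unit) {
  return {in.x, in.x % 2 == 0 ? in.y : unit.y - in.y - 1};
}

// Cyclic XOR within a square tile: the column index is xor'ed with the row.
Idx2 xorSwizzle(Idx2 in) {
  return {in.x, in.y ^ in.x};
}

int main() {
  const Idx2 unit{4, 4};
  for (int x = 0; x < unit.x; ++x) {
    for (int y = 0; y < unit.y; ++y) {
      const Idx2 p{x, y};
      // Applying either swizzle twice returns the original index, so the
      // "un" variants can simply re-apply the forward formula.
      const Idx2 z = zshape(zshape(p, unit), unit);
      const Idx2 q = xorSwizzle(xorSwizzle(p));
      assert(z.x == p.x && z.y == p.y);
      assert(q.x == p.x && q.y == p.y);
    }
  }
  return 0;
}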
+ auto all_inputs = InputsOf::outputs(fusion(), {axis(x), axis(y)}); + for (auto id : ir_utils::filterByType(all_inputs)) { + TORCH_INTERNAL_ASSERT( + !id->isBroadcast() && !id->isReduction(), + "Unsupported use case for swizzle."); + } + + // Also checking that the scheduler is not trying to + // compose swizzles, which is not yet supported either. + auto all_exprs = DependencyCheck::getAllValsBetween( + {all_inputs.begin(), all_inputs.end()}, {axis(x), axis(y)}); + for (auto expr : all_exprs) { + TORCH_INTERNAL_ASSERT( + !expr->isA(), "Composing swizzles is not yet supported"); + } + + // Check swizzle specific constraints on the input axes: + if (swizzle_type != Swizzle2DType::ZShape) { + ExpressionEvaluator const_eval(fusion()); + + auto x_id = axis(x); + auto y_id = axis(y); + + TORCH_INTERNAL_ASSERT( + x_id->extent()->isConstInt() && y_id->extent()->isConstInt(), + "Only constant iterdomains supported on given swizzle type"); + + int in_x_size = x_id->extent()->evaluateInt(); + int in_y_size = y_id->extent()->evaluateInt(); + + // Check size constraints based on swizzle type + if (swizzle_type == Swizzle2DType::Transpose || + swizzle_type == Swizzle2DType::XOR) { + TORCH_INTERNAL_ASSERT( + in_x_size == in_y_size, "Swizzle: equal dim iterdomains only"); + } + + if (swizzle_type == Swizzle2DType::Scatter) { + TORCH_INTERNAL_ASSERT( + in_y_size == 4, "Swizzle: unsupported id size must be 4 ", in_y_size); + TORCH_INTERNAL_ASSERT( + in_x_size == 8 || in_x_size == 16 || in_x_size == 32, + "Swizzle: unsupported id size must be 8, 16, or 32 ", + in_x_size); + } + } + + domain()->swizzle(swizzle_type, x, y); + + return this; +} + TensorView* TensorView::rFactor(const std::vector& axes) { TORCH_INTERNAL_ASSERT( !container()->isA(), @@ -581,11 +662,11 @@ TensorView* TensorView::rFactor(const std::vector& axes) { // !hasComputeAt(), "Cannot rfactor tensors after compute at has been // set."); TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to rFactor a 0-dim TensorView"); - TORCH_INTERNAL_ASSERT(definition()->isA()); FusionGuard fg(fusion()); TORCH_CHECK( definition() != nullptr && - definition()->getExprType() == ExprType::ReductionOp, + (definition()->getExprType() == ExprType::ReductionOp || + definition()->getExprType() == ExprType::MmaOp), "Error rfactoring ", this, " its definition is either a nullptr or not a reduction."); @@ -596,8 +677,6 @@ TensorView* TensorView::rFactor(const std::vector& axes) { !definition()->isA(), "For GroupedReducitonOp, use TensorView::rFactor(const std::vector& axes, const std::vector& tvs)"); - ReductionOp* this_definition = definition()->as(); - // Split tensor view into 2 parts auto domain_pair = domain()->rFactor(axes); @@ -614,21 +693,38 @@ TensorView* TensorView::rFactor(const std::vector& axes) { setDomain(consumer_domain); TensorView* consumer = this; - // Setup dependency chain, inserting producer before this op. - // Expr* producer_definition = - IrBuilder::create( - this_definition->getReductionOpType(), - this_definition->init(), - producer, - this_definition->in()); - - // Expr* consumer_definition = - IrBuilder::create( - this_definition->getReductionOpType(), - this_definition->init(), - consumer, - producer); + if (auto this_reduction = dynamic_cast(definition())) { + // Setup dependency chain, inserting producer before this op. 
+ // Expr* producer_definition = + IrBuilder::create( + this_reduction->getReductionOpType(), + this_reduction->init(), + producer, + this_reduction->in()); + // Expr* consumer_definition = + IrBuilder::create( + this_reduction->getReductionOpType(), + this_reduction->init(), + consumer, + producer); + } else if (auto this_mma = dynamic_cast(definition())) { + // Initial reduction that still uses mma to combine + // the input. + IrBuilder::create( + producer, + this_mma->inA(), + this_mma->inB(), + this_mma->init(), + this_mma->options()); + + // Remaining reduction that can be scheduled cross + // warp or cta. + IrBuilder::create( + BinaryOpType::Add, this_mma->init(), consumer, producer); + } else { + TORCH_INTERNAL_ASSERT(false, "RFactor: unsupported tensor definition"); + } return producer; } diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp index 215869911344a..18f145e71ce9e 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp @@ -196,6 +196,42 @@ class UnswitchInElseChecker : public kir::IrVisitor { } }; +// Basically just TransformPropagator, except that it checks the consistency +// replayPasC with getMatchedLeafPosWithoutReplayPasC, replayCasP with +// getMatchedLeafPosWithoutReplayCasP, and fullSelfReplay with fullSelfMatching: +// - After replayPasC, getMatchedLeafPosWithoutReplayPasC should return the same +// replayed position +// - After replayCasP, getMatchedLeafPosWithoutReplayCasP should return the same +// replayed position +// - After fullSelfReplay, fullSelfMatching should return true +struct TransformPropagatorWithCheck : public TransformPropagator { + public: + virtual void propagateC2P(TensorView* from, TensorView* to) override { + TransformPropagator::propagateC2P(from, to); + auto from_pos = replayed_pos_.at(from); + auto to_pos = replayed_pos_.at(to); + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayPasC( + to, from, from_pos) == to_pos); + } + virtual void propagateP2C(TensorView* from, TensorView* to) override { + TransformPropagator::propagateP2C(from, to); + auto from_pos = replayed_pos_.at(from); + auto to_pos = replayed_pos_.at(to); + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayCasP( + to, from, from_pos) == to_pos); + } + virtual void propagateSibling(TensorView* from, TensorView* to) override { + TransformPropagator::propagateSibling(from, to); + auto from_pos = replayed_pos_.at(from); + auto to_pos = replayed_pos_.at(to); + TORCH_CHECK(from_pos == to_pos); + TORCH_CHECK(TransformReplay::fullSelfMatching(from, to)); + } + using TransformPropagator::TransformPropagator; +}; + } // namespace // 1. Test cases are void() functions. @@ -1326,26 +1362,26 @@ TEST_F(NVFuserTest, FusionParser_CUDA) { // 2. 
use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { - int64_t i51; - i51 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); - if ((i51 < T0.size[0])) { + int64_t i50; + i50 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); + if ((i50 < T0.size[0])) { float T5[1]; T5[0] = 0; T5[0] - = T1[i51]; + = T1[i50]; float T4[1]; T4[0] = 0; T4[0] - = T0[i51]; - float T6[1]; + = T0[i50]; float T2[1]; T2[0] = T4[0] * T5[0]; + float T6[1]; T6[0] = T2[0] * T4[0]; - T3[i51] + T3[i50] = T6[0]; } } @@ -9407,6 +9443,13 @@ TEST_F(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { const int64_t dimx = 1024; const int64_t dimy = 16384; + auto properties = at::cuda::getDeviceProperties(0); + // Require 70KB of smem to run test + const size_t required_smem_size = 70 << 10; + if (properties->sharedMemPerBlockOptin < required_smem_size) { + GTEST_SKIP() << "not enough shared memory space on device to run test"; + } + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({dimx, dimy}, options); at::Tensor aten_static_in = aten_input.narrow(1, 0, static_size); @@ -9604,6 +9647,14 @@ TEST_F(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion, aten_inputs); + + auto properties = at::cuda::getDeviceProperties(0); + // Require 70KB of smem to run test + const size_t required_smem_size = 70 << 10; + if (properties->sharedMemPerBlockOptin < required_smem_size) { + GTEST_SKIP() << "not enough shared memory space on device to run test"; + } + fe.runFusion(aten_inputs, {cg_static_out, cg_dynamic_out}); auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1); @@ -13807,7 +13858,7 @@ TEST_F(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { tv3->reorder({{4, 2}}); // [bidx, unswitch, vectorize{2}, unroll{2}, tidx] - TransformPropagator propagator(tv3); + TransformPropagatorWithCheck propagator(tv3); MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); @@ -14446,7 +14497,7 @@ TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { // Split inner-most dim tv2->split(-1, kNumElems); tv2->split(-1, kVecSize); - TransformPropagator propagator(tv2); + TransformPropagatorWithCheck propagator(tv2); MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); c0->computeAt(tv2, -2); @@ -14509,7 +14560,7 @@ TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) { } tv2->split(-1, kNumElems); tv2->split(-1, kVecSize); - TransformPropagator propagator(tv2); + TransformPropagatorWithCheck propagator(tv2); MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); c0->computeAt(tv2, -2); @@ -15169,7 +15220,7 @@ TEST_F(NVFuserTest, FusionValidateParallelize6_CUDA) { tv4->split(0, 3); tv4->split(0, 2); - TransformPropagator propagator(tv4); + TransformPropagatorWithCheck propagator(tv4); MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); tv0->computeAt(tv2, 2); @@ -16540,7 +16591,7 @@ TEST_F(NVFuserTest, FusionSimpleWarp_CUDA) { tv1->split(1, 32); auto tv1_rf = tv1->rFactor({1}); - TransformPropagator propagator(tv1_rf); + TransformPropagatorWithCheck propagator(tv1_rf); MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); 
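For reference, the 70KB shared-memory guard added to the two persistent-kernel tests above can also be expressed against the raw CUDA runtime; a minimal host-side sketch, not part of the test suite:

#include <cuda_runtime.h>
#include <cstdio>

int main() {
  int device = 0;
  int smem_optin = 0;
  cudaGetDevice(&device);
  // Same quantity as getDeviceProperties(0)->sharedMemPerBlockOptin used in
  // the test guards above.
  cudaDeviceGetAttribute(
      &smem_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  const int required_smem_size = 70 << 10;  // 70KB, matching the tests
  if (smem_optin < required_smem_size) {
    std::printf("skip: only %d bytes of opt-in smem per block\n", smem_optin);
  } else {
    std::printf("ok: %d bytes of opt-in smem per block\n", smem_optin);
  }
  return 0;
}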
tv1->axis(0)->parallelize(ParallelType::BIDx); @@ -16587,7 +16638,7 @@ TEST_F(NVFuserTest, FusionSimpleWarpPad_CUDA) { tv1_rf->axis(-1)->padToMultipleOfWarp(32); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(32); - TransformPropagator propagator(tv1_rf); + TransformPropagatorWithCheck propagator(tv1_rf); MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0->axis(-1)->padToMultipleOfWarp(32); @@ -16636,7 +16687,7 @@ TEST_F(NVFuserTest, FusionWarpPadMergeSplit_CUDA) { tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(); - TransformPropagator propagator(tv1_rf); + TransformPropagatorWithCheck propagator(tv1_rf); MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); @@ -16678,7 +16729,7 @@ TEST_F(NVFuserTest, FusionSerialWarpReduction_CUDA) { tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(); - TransformPropagator propagator(tv1); + TransformPropagatorWithCheck propagator(tv1); MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); @@ -16723,7 +16774,7 @@ TEST_F(NVFuserTest, FusionTrivialWarpReduction_CUDA) { tv1_rf->axis(-2)->parallelize(ParallelType::TIDx); tv1->axis(-2)->parallelize(ParallelType::TIDx); tv1->axis(-2)->padToMultipleOfWarp(); - TransformPropagator propagator(tv1_rf); + TransformPropagatorWithCheck propagator(tv1_rf); MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv0->axis(-2)->parallelize(ParallelType::TIDx); tv0_cache->axis(-2)->parallelize(ParallelType::TIDx); @@ -16771,7 +16822,7 @@ TEST_F(NVFuserTest, FusionMultipleDimBinding_CUDA) { tv1_rf->axis(-1)->padToMultipleOfWarp(32); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(32); - TransformPropagator propagator(tv1_rf); + TransformPropagatorWithCheck propagator(tv1_rf); MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0->axis(-1)->padToMultipleOfWarp(32); @@ -16854,7 +16905,7 @@ TEST_F(NVFuserTest, FusionWarpMutipleThreadDim_CUDA) { tv2_rf->axis(-1)->parallelize(ParallelType::TIDx); tv2_rf->axis(-1)->padToMultipleOfWarp(); - TransformPropagator propagator(tv2_rf); + TransformPropagatorWithCheck propagator(tv2_rf); MaxRootDomainInfoSpanningTree(tv2_rf).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); @@ -16901,7 +16952,7 @@ TEST_F(NVFuserTest, FusionWarpReduceUnrollOuterLoop_CUDA) { tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->padToMultipleOfWarp(); tv1->axis(1)->parallelize(ParallelType::Unroll); - TransformPropagator propagator(tv1_rf); + TransformPropagatorWithCheck propagator(tv1_rf); MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); tv0->axis(-1)->parallelize(ParallelType::TIDx); tv0->axis(1)->parallelize(ParallelType::Unroll); @@ -17104,7 +17155,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination3_CUDA) { tv1->split(0, 10); tv1->split(0, 33); - TransformPropagator propagator(tv1); + TransformPropagatorWithCheck propagator(tv1); MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); auto tv4 = tv1->rFactor({-1}); @@ -17158,7 +17209,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination4_CUDA) { tv1->split(1, 7); tv1->split(0, 11); 
tv1->reorder({{1, 2}, {2, 1}}); - TransformPropagator propagator(tv1); + TransformPropagatorWithCheck propagator(tv1); MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); tv1->axis(0)->parallelize(ParallelType::TIDy); @@ -17207,7 +17258,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination5_CUDA) { fusion.addOutput(tv3); tvs2.avg->split(0, 4); - TransformPropagator propagator(tvs2.avg); + TransformPropagatorWithCheck propagator(tvs2.avg); MaxRootDomainInfoSpanningTree(tvs2.avg).traverse(&propagator); auto rtvs2 = tvs2.rFactor({1}); @@ -17252,7 +17303,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination6_CUDA) { fusion.addOutput(tv4); tv4->split(1, 5); - TransformPropagator propagator(tv4); + TransformPropagatorWithCheck propagator(tv4); MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); tv4->reorder({{0, 1}, {1, 0}}); @@ -17300,7 +17351,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination7_CUDA) { tv3->split(-1, 5); tv3->split(-1, 4); tv3->split(-1, 3); - TransformPropagator propagator(tv3); + TransformPropagatorWithCheck propagator(tv3); MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); @@ -18938,7 +18989,7 @@ TEST_F(NVFuserTest, FusionFloatPow_CUDA) { tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(1)->parallelize(ParallelType::TIDx); - TransformPropagator propagator(tv1); + TransformPropagatorWithCheck propagator(tv1); MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); scheduler_utils::parallelizeAllLike(tv1, {tv2, tv3, tv4, tv5, tv6}); @@ -19046,9 +19097,9 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { // 2. use a fuzzy compare (ignore non-significant whitespaces for example) const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { - int64_t i172; - i172 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); - if ((i172 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) { + int64_t i171; + i171 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); + if ((i171 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) { __half T9[1]; T9[0] = 0; T9[0] @@ -19056,8 +19107,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, __half T8[1]; T8[0] = 0; T8[0] - = T0[i172]; - __half T10[1]; + = T0[i171]; float T3[1]; T3[0] = __half2float(T9[0]); @@ -19074,9 +19124,10 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, float T6[1]; T6[0] = relu(T5[0]); + __half T10[1]; T10[0] = __float2half(T6[0]); - T7[i172] + T7[i171] = T10[0]; } } @@ -20182,7 +20233,7 @@ TEST_F(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { tv3->split(0, 8, false); tv3->split(1, 4); - TransformPropagator propagator(tv3); + TransformPropagatorWithCheck propagator(tv3); MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv3->axis(1)->parallelize(ParallelType::TIDx); @@ -20345,7 +20396,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering1_CUDA) { tv3->split(-1, 128); tv3->split(-1, 32); - TransformPropagator propagator(tv3); + TransformPropagatorWithCheck propagator(tv3); MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); @@ -20383,7 +20434,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering2_CUDA) { tv3->split(-1, 128); tv3->split(-1, 32); - TransformPropagator propagator(tv3); + TransformPropagatorWithCheck propagator(tv3); MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, -1); @@ -20423,7 +20474,7 @@ 
TEST_F(NVFuserTest, FusionDoubleBuffering3_CUDA) { tv3->split(-1, 128); tv3->split(-1, 32); - TransformPropagator propagator(tv3); + TransformPropagatorWithCheck propagator(tv3); MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); @@ -20472,7 +20523,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering4_CUDA) { tv3->split(-1, 128); tv3->split(-1, 32); tv3->split(-1, 8); - TransformPropagator propagator(tv3); + TransformPropagatorWithCheck propagator(tv3); MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 2); @@ -20514,7 +20565,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering5_CUDA) { tv2->split(-1, 128); tv2->split(-1, 32); tv2->split(-1, 8); - TransformPropagator propagator(tv2); + TransformPropagatorWithCheck propagator(tv2); MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); tv0->computeAt(tv2, 2); @@ -20558,7 +20609,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering6_CUDA) { tv3->split(-1, 16); tv3->split(-2, 4); tv3->split(-2, 2); - TransformPropagator propagator(tv3); + TransformPropagatorWithCheck propagator(tv3); MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv0->computeAt(tv3, 1); @@ -20596,7 +20647,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering7_CUDA) { tv2->split(-1, 128); tv2->split(-1, 4); - TransformPropagator propagator(tv2); + TransformPropagatorWithCheck propagator(tv2); MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); tv1->computeAt(tv2, 2); @@ -20637,7 +20688,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering8_CUDA) { tv4->split(0, 32); tv4->split(0, 4); - TransformPropagator propagator(tv4); + TransformPropagatorWithCheck propagator(tv4); MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); tv0->computeAt(tv4, 1); @@ -20679,7 +20730,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering9_CUDA) { out->split(0, 32); out->split(0, 4); - TransformPropagator propagator(out); + TransformPropagatorWithCheck propagator(out); MaxRootDomainInfoSpanningTree(out).traverse(&propagator); tv2->setMemoryType(MemoryType::Shared); @@ -20748,7 +20799,7 @@ TEST_F(NVFuserTest, FusionSmemBlockGemmCacheDoubleBuffer_CUDA) { auto tv6_rf = tv6->rFactor({-1}); - TransformPropagator propagator(tv6_rf); + TransformPropagatorWithCheck propagator(tv6_rf); MaxRootDomainInfoSpanningTree(tv6_rf).traverse(&propagator); tv0->computeAt(tv6, 3); @@ -20815,7 +20866,7 @@ TEST_F(NVFuserTest, FusionIntermediateTensorVectorize_CUDA) { tv1->setMemoryType(mem_type); tv3->split(-1, 4); - TransformPropagator propagator(tv3); + TransformPropagatorWithCheck propagator(tv3); MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv1->computeAt(tv3, -2); @@ -21465,7 +21516,7 @@ TEST_F(NVFuserTest, FusionIndexHoist2_CUDA) { fusion.addOutput(tv5); tv5->split(-1, 4); - TransformPropagator propagator(tv5); + TransformPropagatorWithCheck propagator(tv5); MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator); tv4->split(-1, 3); @@ -21558,7 +21609,7 @@ TEST_F(NVFuserTest, FusionTestGridComm2_CUDA) { tv4->merge(0); tv4->split(0, 2); - TransformPropagator propagator(tv4); + TransformPropagatorWithCheck propagator(tv4); MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); tv3->computeAt(tv4, 1); @@ -21925,7 +21976,7 @@ TEST_F(NVFuserTest, FusionContigIndexingWithBroadcast_CUDA) { fusion.addOutput(tv3); tv3->merge(0); - TransformPropagator propagator(tv3); + TransformPropagatorWithCheck propagator(tv3); MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv2->setMemoryType(MemoryType::Local); @@ -21975,7 +22026,7 @@ TEST_F(NVFuserTest, 
FusionVectorizeContigIndexValidationFail2_CUDA) { tv4->merge(1, 2); tv4->merge(0, 1); tv4->split(0, 4); - TransformPropagator propagator(tv4); + TransformPropagatorWithCheck propagator(tv4); MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); tv0->computeAt(tv4, -2); @@ -22021,7 +22072,7 @@ TEST_F(NVFuserTest, FusionVectorizeContigIndexWithBroadcast_CUDA) { // Don't modify tv1 so that it's replayed as tv2 with actual // transformations. It would create temporary IterDomains, and the // validation should still be able to detect vectorization by 4 is valid. - // TransformPropagator propagator(tv3); + // TransformPropagatorWithCheck propagator(tv3); // MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); tv2->merge(1, 2); @@ -22106,7 +22157,7 @@ TEST_F(NVFuserTest, FusionTrivialReductionForwarding1_CUDA) { tv2->merge(0); tv2->split(0, 4); - TransformPropagator propagator(tv2); + TransformPropagatorWithCheck propagator(tv2); MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); // All tensors must be transformed to a 2D tensor with each axis @@ -22869,7 +22920,7 @@ TEST_F(NVFuserTest, FusionPropagateParallelTypesToSiblings_CUDA) { fusion.addOutput(tv_avg); tv_avg->split(0, 128); - TransformPropagator propagator(tv_avg); + TransformPropagatorWithCheck propagator(tv_avg); MaxRootDomainInfoSpanningTree(tv_avg).traverse(&propagator); tv_avg->axis(0)->parallelize(ParallelType::BIDx); @@ -23097,7 +23148,7 @@ TEST_F(NVFuserTest, FusionIncompleteConcreteID_CUDA) { tv6->merge(0); tv6->merge(0); - TransformPropagator propagator(tv6); + TransformPropagatorWithCheck propagator(tv6); MaxRootDomainInfoSpanningTree(tv6).traverse(&propagator); tv0->computeAt(tv6, -1, ComputeAtMode::MostInlined); @@ -23159,7 +23210,7 @@ TEST_F(NVFuserTest, FusionTestReEntrantGridWelford_CUDA) { // T2_g[iblockIdx.x, ithreadIdx.x24, rblockIdx.y, rthreadIdx.y, rS{16}, // iV25{4}] - TransformPropagator propagator(reduction_tv); + TransformPropagatorWithCheck propagator(reduction_tv); MaxRootDomainInfoSpanningTree(reduction_tv).traverse(&propagator); auto rfactor_tv = ir_utils::rfactorHelper(reduction_tv, {4}); scheduler_utils::parallelizeAllLike(rfactor_tv, ir_utils::allTvs(&fusion)); @@ -23257,6 +23308,240 @@ TEST_F(NVFuserTest, FusionRedundantPredSync_CUDA) { testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); } +// Test a basic swizzle pattern +TEST_F(NVFuserTest, FusionSimpleSwizzle0_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 32}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + + fusion.addOutput(tv2); + + // Make a 2x8 Zshape tile + tv1->split(-1, 16); + tv1->split(-1, 8); + // [O, 2, 8] + + tv2->split(-1, 16); + tv2->split(-1, 4); + //[O, 4, 4] + + tv1->computeAt(tv2, 1); + tv1->swizzle(Swizzle2DType::ZShape, -2, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({2, 32}, options); + auto t2 = t0 + 2.0; + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__); +} + +// Test swizzle inlining +TEST_F(NVFuserTest, FusionSimpleSwizzle1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 32}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv3); + + // 
Make a 2x8 Zshape tile + tv2->split(-1, 16); + tv2->split(-1, 8); + // [O, 2, 8] + + tv3->split(-1, 16); + tv3->split(-1, 4); + //[O, 4, 4] + + tv2->computeAt(tv3, 1); + tv2->swizzle(Swizzle2DType::ZShape, -2, -1); + + // Inlining a producer into a swizzled consumer is ok + tv1->computeAt(tv2, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({2, 32}, options); + auto t3 = t0 + 3.0; + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, {t3}, __LINE__, __FILE__); +} + +// Test sync insertion and memory check in parallelized swizzles. +// In this test, data is parallel written into smem in zcurve +// pattern and then read out and output to global mem unswizzled. +TEST_F(NVFuserTest, FusionSimpleSwizzle2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + + fusion.addOutput(tv2); + + tv1->swizzle(Swizzle2DType::ZShape, -2, -1); + + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv1->axis(1)->parallelize(ParallelType::TIDy); + + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDy); + + // Validation should fail since TV1 is not in shared + // memory as required by sync info pass. + ASSERT_ANY_THROW(GpuLower gpulw_throw(&fusion)); + + tv1->setMemoryType(MemoryType::Shared); + + // Make sure that a sync is inserted: + bool sync_found = false; + GpuLower gpu_lw(&fusion); + auto flattened_exps = + ir_utils::flattenScopedExprs(gpu_lw.kernel()->topLevelExprs()); + + for (auto expr : flattened_exps) { + if (expr->isA()) { + sync_found = true; + } + // Will require a sync thread before any shared memory read. + for (auto inp_tv : ir_utils::filterByType(expr->inputs())) { + if (inp_tv->getMemoryType() == MemoryType::Shared) { + TORCH_INTERNAL_ASSERT( + sync_found, "Block sync required but not inserted"); + } + } + } + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({32, 32}, options); + auto t2 = t0 + 2.0; + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__); +} + +// Test BestEffortReplay behavior with swizzle op +TEST_F(NVFuserTest, FusionSwizzleMapping_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 32}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv3); + + // Make a 2x8 Zshape tile + tv2->split(-1, 16); + tv2->split(-1, 8); + // [O, 2, 8] + + tv3->split(-1, 16); + tv3->split(-1, 4); + //[O, 4, 4] + + tv2->computeAt(tv3, 1); + tv2->swizzle(Swizzle2DType::ZShape, -2, -1); + + // Inlining a producer into a swizzled consumer is ok + tv1->computeAt(tv2, -1); + + // Check BestEffortReplay behavior with skip swizzles option on. + PairwiseRootDomainMap root_map(tv1, tv2); + + // Check producer to consumer map, + // i.e. 
unswizzled tensor to swizzled tensor map + //---------------------------------------------------------- + auto p2c = BestEffortReplay::replayCasP(tv2, tv1, -1, root_map).getReplay(); + auto swizzle_x_it0 = p2c.find(tv1->axis(-2)); + auto swizzle_y_it0 = p2c.find(tv1->axis(-1)); + // P2C map should exist and both the x and y map should + // map to the output of the swizzle op. + TORCH_INTERNAL_ASSERT( + swizzle_x_it0 != p2c.end() && swizzle_y_it0 != p2c.end()); + TORCH_INTERNAL_ASSERT( + swizzle_x_it0->second == tv2->axis(-2) && + swizzle_y_it0->second == tv2->axis(-1)); + + // Check consumer to producer map, + // i.e. swizzled tensor to unswizzled tensor map + //---------------------------------------------------------- + auto c2p = BestEffortReplay::replayPasC(tv1, tv2, -1, root_map).getReplay(); + + auto swizzle_op = tv2->axis(-1)->definition()->as(); + + // Find mapping for swizzle inputs + auto swizzle_x_it1 = c2p.find(swizzle_op->inX()); + auto swizzle_y_it1 = c2p.find(swizzle_op->inY()); + + // Find mapping for swizzle outputs + auto swizzle_x_it2 = c2p.find(swizzle_op->outX()); + auto swizzle_y_it2 = c2p.find(swizzle_op->outY()); + + // Input of swizzle ops will not be mapped to any + // by BestEffortReplay, as BestEffortReplay has to be + // one to one. IdGraph will further map them together. + TORCH_INTERNAL_ASSERT( + swizzle_x_it1 == c2p.end() && swizzle_y_it1 == c2p.end()); + + // Mapping for swizzle outputs should be mapped and should + // also map to the corresponding axes on the unswizzled tensor. + TORCH_INTERNAL_ASSERT( + swizzle_x_it2 != c2p.end() && swizzle_y_it2 != c2p.end()); + TORCH_INTERNAL_ASSERT( + swizzle_x_it2->second == tv1->axis(-2) && + swizzle_y_it2->second == tv1->axis(-1)); + + // Check id graph behavior + //---------------------------------------------------------- + ComputeAtMap ca_map(&fusion); + // Corresponding inputs and outputs of swizzle ops are + // map through by exact and permissive map. 
+ TORCH_INTERNAL_ASSERT( + ca_map.areMapped(tv1->axis(-2), swizzle_op->inX(), IdMappingMode::EXACT)); + TORCH_INTERNAL_ASSERT( + ca_map.areMapped(tv1->axis(-1), swizzle_op->inY(), IdMappingMode::EXACT)); + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-2), swizzle_op->outX(), IdMappingMode::EXACT)); + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-1), swizzle_op->outY(), IdMappingMode::EXACT)); + + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-2), swizzle_op->inX(), IdMappingMode::PERMISSIVE)); + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-1), swizzle_op->inY(), IdMappingMode::PERMISSIVE)); + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-2), swizzle_op->outX(), IdMappingMode::PERMISSIVE)); + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-1), swizzle_op->outY(), IdMappingMode::PERMISSIVE)); +} + TEST_F(NVFuserTest, FusionUnsqueeze1_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -23650,6 +23935,46 @@ TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) { executor_cache.fusion(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__); } +namespace { + +// check that the resulting sibling are identical +void checkSiblingConsistency(TensorView* replay, TensorView* target) { + auto replay_root = replay->getRootDomain(); + auto replay_dom = replay->domain()->domain(); + auto target_root = target->getRootDomain(); + auto target_dom = target->domain()->domain(); + std::unordered_map target2replay_map; + TORCH_CHECK(replay_root.size() == target_root.size()); + target2replay_map.reserve(replay_root.size()); + std::transform( + target_root.begin(), + target_root.end(), + replay_root.begin(), + std::inserter(target2replay_map, target2replay_map.begin()), + [](auto a, auto b) { return std::make_pair(a, b); }); + BestEffortReplay replay_(replay_dom, target_dom, target2replay_map); + auto r = replay_.getReplay(); + for (int64_t i = 0; i < replay_dom.size(); i++) { + auto target_id = target_dom[i]; + auto replay_it = r.find(target_id); + TORCH_CHECK(replay_it != r.end()); + TORCH_CHECK( + replay_it->second == replay_dom[i], + "IterDomain mismatch when checking ", + replay, + " and ", + target, + " at ", + i, + ", got ", + replay_it->second, + " and ", + replay_dom[i]); + } +}; + +} // namespace + TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { // https://github.com/csarofeen/pytorch/issues/1760 Fusion fusion; @@ -23673,55 +23998,85 @@ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { auto tvs2 = tvs.rFactor({1, 4}); - TransformPropagator propagator(tvs2.var_sum); + TransformPropagatorWithCheck propagator(tvs2.var_sum); MaxRootDomainInfoSpanningTree(tvs2.var_sum).traverse(&propagator); - // check that the resulting tensors in tvs2 are identical - auto checkSiblingConsistency = [](TensorView* replay, TensorView* target) { - auto replay_root = replay->getRootDomain(); - auto replay_dom = replay->domain()->domain(); - auto target_root = target->getRootDomain(); - auto target_dom = target->domain()->domain(); - std::unordered_map target2replay_map; - TORCH_CHECK(replay_root.size() == target_root.size()); - target2replay_map.reserve(replay_root.size()); - std::transform( - target_root.begin(), - target_root.end(), - replay_root.begin(), - std::inserter(target2replay_map, target2replay_map.begin()), - [](auto a, auto b) { return std::make_pair(a, b); }); - BestEffortReplay replay_(replay_dom, target_dom, target2replay_map); - auto r = replay_.getReplay(); - for (int64_t i = 0; i < replay_dom.size(); i++) { - auto target_id = target_dom[i]; - auto 
replay_it = r.find(target_id); - TORCH_CHECK(replay_it != r.end()); - TORCH_CHECK( - replay_it->second == replay_dom[i], - "IterDomain mismatch when checking ", - replay, - " and ", - target, - " at ", - i, - ", got ", - replay_it->second, - " and ", - replay_dom[i]); - } - }; std::vector siblings[] = { {tvs.avg, tvs.var_sum, tvs.n}, {tvs2.avg, tvs2.var_sum, tvs2.n}}; for (auto tensors : siblings) { for (auto t1 : tensors) { for (auto t2 : tensors) { - checkSiblingConsistency(t1, t2); + TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2)); } } } } +TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tvs = Welford(tv0, {1}); + fusion.addOutput(tvs.var_sum); + + tvs.avg->split(1, 1); + tvs.avg->split(1, 2); + tvs.avg->split(1, 3); + tvs.var_sum->split(1, 1); + tvs.var_sum->split(1, 2); + tvs.var_sum->split(1, 3); + tvs.n->split(1, 1); + tvs.n->split(1, 2); + tvs.n->split(1, 3); + + auto tvs2 = tvs.rFactor({1, 4}); + + struct DisableTv0 : public MaxInfoSpanningTree::Selector { + TensorView* tv0; + virtual bool allowC2P(TensorView* from, TensorView* to) override { + return from != tv0 && to != tv0; + }; + virtual bool allowP2C(TensorView* from, TensorView* to) override { + return from != tv0 && to != tv0; + }; + virtual bool allowSibling(TensorView* from, TensorView* to) override { + return true; + } + DisableTv0(TensorView* tv0) : tv0(tv0) {} + } selector1(tv0); + + struct DisableTv0AndSibling : public DisableTv0 { + virtual bool allowSibling(TensorView* from, TensorView* to) override { + return false; + } + using DisableTv0::DisableTv0; + } selector2(tv0); + + TransformPropagatorWithCheck propagator(tvs2.var_sum); + MaxRootDomainInfoSpanningTree good_path(tvs2.var_sum, &selector1); + MaxRootDomainInfoSpanningTree bad_path(tvs2.var_sum, &selector2); + + auto check = [&]() { + std::vector siblings[] = { + {tvs.avg, tvs.var_sum, tvs.n}, {tvs2.avg, tvs2.var_sum, tvs2.n}}; + for (auto tensors : siblings) { + for (auto t1 : tensors) { + for (auto t2 : tensors) { + TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2)); + } + } + } + }; + + bad_path.traverse(&propagator); + ASSERT_ANY_THROW(check()); + good_path.traverse(&propagator); + check(); +} + TEST_F(NVFuserTest, FusionTransformPropagatePosition_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -23736,7 +24091,7 @@ TEST_F(NVFuserTest, FusionTransformPropagatePosition_CUDA) { tv0->merge(2); tv0->merge(0); - TransformPropagator propagator(tv0); + TransformPropagatorWithCheck propagator(tv0); MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); TORCH_CHECK(tv1->nDims() == 4); @@ -23819,7 +24174,7 @@ TEST_F(NVFuserTest, FusionIssue1770Repro_CUDA) { __FILE__); } -TEST_F(NVFuserTest, FusionTransormPropagatorSelector_CUDA) { +TEST_F(NVFuserTest, FusionTransformPropagatorSelector_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -23841,16 +24196,19 @@ TEST_F(NVFuserTest, FusionTransormPropagatorSelector_CUDA) { struct Selector : public MaxInfoSpanningTree::Selector { TensorView* tv0; TensorView* tv3; - virtual bool allowPasC(TensorView* from, TensorView* to) override { + virtual bool allowC2P(TensorView* from, TensorView* to) override { return to == tv0; } - virtual bool allowCasP(TensorView* from, TensorView* to) override { + virtual bool allowP2C(TensorView* from, TensorView* to) override { return to == tv3; } + virtual bool allowSibling(TensorView* from, TensorView* to) override { + 
return false; + } Selector(TensorView* tv0, TensorView* tv3) : tv0(tv0), tv3(tv3) {} } selector(tv0, tv3); - TransformPropagator propagator(tv2); + TransformPropagatorWithCheck propagator(tv2); MaxRootDomainInfoSpanningTree(tv2, &selector).traverse(&propagator); TORCH_CHECK(tv0->nDims() == 2); @@ -23860,7 +24218,7 @@ TEST_F(NVFuserTest, FusionTransormPropagatorSelector_CUDA) { TORCH_CHECK(tv4->nDims() == 1); } -TEST_F(NVFuserTest, FusionTransormPropagatorPos_CUDA) { +TEST_F(NVFuserTest, FusionTransformPropagatorPos_CUDA) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -23874,13 +24232,12 @@ TEST_F(NVFuserTest, FusionTransormPropagatorPos_CUDA) { tv1->split(-1, 3); tv1->split(-1, 5); - TransformPropagator propagator(tv1, 2); + TransformPropagatorWithCheck propagator(tv1, 2); MaxRootDomainInfoSpanningTree(tv1, 2).traverse(&propagator); - TORCH_CHECK(tv0->nDims() == 3); - TORCH_CHECK(tv0->axis(0)->extent()->evaluateInt() == 11); - TORCH_CHECK(tv0->axis(1)->extent()->evaluateInt() == 2); - TORCH_CHECK(tv0->axis(2)->extent()->evaluateInt() == 105); + auto expect = makeConcreteTensor({22, 105}); + expect->split(0, 2); + TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv0)); } TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) { @@ -23899,13 +24256,18 @@ TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) { struct Printer : public MaxInfoSpanningTree::Propagator { std::stringstream ss; - virtual void propagateTvPasC(TensorView* from, TensorView* to) override { - ss << "propagateTvPasC" << std::endl; + virtual void propagateC2P(TensorView* from, TensorView* to) override { + ss << "propagateC2P" << std::endl; ss << "from: " << from->name() << std::endl; ss << "to: " << to->name() << std::endl; } - virtual void propagateTvCasP(TensorView* from, TensorView* to) override { - ss << "propagateTvCasP" << std::endl; + virtual void propagateP2C(TensorView* from, TensorView* to) override { + ss << "propagateP2C" << std::endl; + ss << "from: " << from->name() << std::endl; + ss << "to: " << to->name() << std::endl; + } + virtual void propagateSibling(TensorView* from, TensorView* to) override { + ss << "propagateSibling" << std::endl; ss << "from: " << from->name() << std::endl; ss << "to: " << to->name() << std::endl; } @@ -23918,10 +24280,10 @@ TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) { path.traverse(&printer2); auto expect = R"ESCAPE( -propagateTvPasC +propagateC2P from: 1 to: 0 -propagateTvCasP +propagateP2C from: 1 to: 2 )ESCAPE"; @@ -23929,6 +24291,152 @@ to: 2 TORCH_CHECK(printer2.ss.str() == expect); } +TEST_F(NVFuserTest, FusionTransformPropagatorNoOverwrite_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + auto tv1 = broadcast(tv0, {true, false, true}); + auto tv2 = sin(tv1); + fusion->addOutput(tv2); + + tv0->split(0, 2); + tv2->split(1, 2); + tv2->split(0, 4); + + MaxRootDomainInfoSpanningTree path1(tv2); + TransformPropagatorWithCheck propagator1(tv2); + path1.traverse(&propagator1); + + MaxRootDomainInfoSpanningTree path2(tv0); + TransformPropagatorWithCheck propagator2(tv0); + path2.traverse(&propagator2); + + TORCH_CHECK(tv1->axis(0)->isBroadcast()); + TORCH_CHECK(tv1->axis(1)->isBroadcast()); + TORCH_CHECK(!tv1->axis(2)->isBroadcast()); + TORCH_CHECK(!tv1->axis(3)->isBroadcast()); + TORCH_CHECK(tv1->axis(4)->isBroadcast()); + + auto expect = makeSymbolicTensor(3); + expect->split(1, 2); + expect->split(0, 
4); + TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv1)); +} + +TEST_F(NVFuserTest, FusionIssue1785Repro_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeContigTensor(1); + TensorView* tv1 = makeContigTensor(2); + + // Register your inputs + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + // [B, I] + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv3, tv1); + auto tv5 = set(tv4); + + // Register your outputs + fusion.addOutput(tv5); + + tv5->split(0, 8); + tv5->split(-1, 8); + + // [Serial, TIDy, TIDX, Serial] + + tv4->computeAt(tv5, -2); + tv3->computeAt(tv4, -1); + tv2->computeAt(tv3, 0); + tv2->split(0, 8); + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv1->computeAt(tv5, -2); + + tv5->axis(1)->parallelize(ParallelType::TIDy); + tv5->axis(2)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor in1 = at::randn({16}, options); + at::Tensor in2 = at::randn({12, 16}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {in1, in2}); + auto cg_outputs = fe.runFusion({in1, in2}); + + auto tv_ref = in1 + in2; + + testValidate(&fusion, cg_outputs, {in1, in2}, {tv_ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSkipReplay_CUDA) { + { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigTensor(1); + TensorView* tv1 = makeContigTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {false, true}); + auto tv3 = add(tv2, tv1); + fusion.addOutput(tv3); + + tv3->split(1, 2, false); + + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + } + + { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigTensor(3); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0, 2}); + auto tv2 = sin(tv1); + fusion.addOutput(tv2); + + tv0->split(1, 2, false); + + TransformPropagatorWithCheck propagator(tv0); + MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); + } +} + +TEST_F(NVFuserTest, FusionInlineRepro1803_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigTensor(2); + + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tvs = Welford(tv1, {1}); + auto tvo = set(tvs.var_sum); + fusion.addOutput(tvo); + + tvo->split(0, 16); + tvo->axis(1)->parallelize(ParallelType::Unroll); + + tv0->computeAt(tvo, -1, ComputeAtMode::BestEffort); + + TORCH_CHECK( + tvs.var_sum->getComputeAtPosition() == tvs.avg->getComputeAtPosition()); + TORCH_CHECK( + tvs.var_sum->getComputeAtPosition() == tvs.n->getComputeAtPosition()); + TORCH_CHECK(tvs.var_sum->getComputeAtPosition() == 1); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp index 720bcbf63fc61..ead06100153de 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp @@ -117,6 +117,7 @@ void validateNoParallelBroadcastExist(kir::Kernel* kernel) { TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) { const int nx = 999; const int tidx = 128; + const int bidx = 4; if (ceilDiv(nx, tidx) > deviceSMCount()) { GTEST_SKIP() << "Not enough SMs to run this test"; @@ -135,13 +136,20 @@ TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) { fusion.addOutput(tv3); tv3->split(0, 
tidx); + tv3->split(0, bidx); + tv3->split(0, 1); // unswitch TransformPropagator propagator(tv3); MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(1)->parallelize(ParallelType::TIDx); + tv3->axis(0)->parallelize(ParallelType::BIDy); + tv3->axis(2)->parallelize(ParallelType::BIDx); + tv3->axis(3)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion)); + // Just to make sure fused_reduction and work buffers are allocated + // uniquely + tv1->axis(1)->parallelize(ParallelType::Unswitch); + GpuLower gpulw(&fusion); validateNoParallelBroadcastExist(gpulw.kernel()); @@ -1796,6 +1804,8 @@ TEST_F( tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize); tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize); + tv5->axis(-1)->parallelize(ParallelType::Group); + auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); auto options_float = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -1828,6 +1838,338 @@ TEST_F( fe.kernel(), outputs, aten_inputs, {t11, t13}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = sum(tv1, {0}); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv0, tv3); + fusion.addOutput(tv4); + + const int vec = 2; + const int tidx = 32; + const int tidy = 8; + + tv1->split(1, vec); + tv1->split(1, tidx); + tv1->split(0, tidy); + TransformPropagator propagator(tv1); + MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + + tv1->axis(0)->parallelize(ParallelType::BIDy); + tv1->axis(1)->parallelize(ParallelType::TIDy); + tv1->axis(2)->parallelize(ParallelType::BIDx); + tv1->axis(3)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion)); + + tv2->axis(4)->parallelize(ParallelType::Group); + + // Make sure the reduction expr is converted to GroupedGridReduciton + // and the non-reduction domains of the output TV are either + // grouped or parallelized + GpuLower gpulw(&fusion); + bool validated = false; + for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { + auto grouped_grid_reduction = + dynamic_cast(expr); + if (grouped_grid_reduction == nullptr) { + continue; + } + auto out = ir_utils::getTvOutput(grouped_grid_reduction); + for (auto out_axis : out->domain()->domain()) { + auto out_axis_pt = out_axis->getParallelType(); + TORCH_CHECK( + isParallelTypeThread(out_axis_pt) || + out_axis_pt == ParallelType::Group, + "Invalid parallel type of the reduction tensor: ", + out_axis_pt, + ". Reduction output tensor: ", + out->toString()); + } + validated = true; + } + TORCH_CHECK( + validated, "Invalid lowered kernel. 
No GroupedGridReduction found."); + + std::vector shape({99, 101}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn(shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + auto t0_double = t0.to(at::kDouble); + auto ref = t0_double + t0_double.sum({0}).unsqueeze(0); + + testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Test grouping of two domains +TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = sum(tv1, {0}); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv0, tv3); + fusion.addOutput(tv4); + + const int vec1 = 2; + const int vec2 = 3; + const int tidx = 16; + const int tidy = 8; + + tv1->split(1, vec1); + tv1->split(1, vec2); + tv1->split(1, tidx); + tv1->split(0, tidy); + TransformPropagator propagator(tv1); + MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + + tv1->axis(0)->parallelize(ParallelType::BIDy); + tv1->axis(1)->parallelize(ParallelType::TIDy); + tv1->axis(2)->parallelize(ParallelType::BIDx); + tv1->axis(3)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion)); + + tv2->axis(4)->parallelize(ParallelType::Group); + tv2->axis(5)->parallelize(ParallelType::Group); + + std::vector shape({99, 129}); + + // Make sure the reduction expr is converted to GroupedGridReduciton + // and the non-reduction domains of the output TV are either + // grouped or parallelized + GpuLower gpulw(&fusion); + bool validated = false; + for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { + auto grouped_grid_reduction = + dynamic_cast(expr); + if (grouped_grid_reduction == nullptr) { + continue; + } + auto out = ir_utils::getTvOutput(grouped_grid_reduction); + for (auto out_axis : out->domain()->domain()) { + auto out_axis_pt = out_axis->getParallelType(); + TORCH_CHECK( + isParallelTypeThread(out_axis_pt) || + out_axis_pt == ParallelType::Group, + "Invalid parallel type of the reduction tensor: ", + out_axis_pt, + ". Reduction output tensor: ", + out->toString()); + } + validated = true; + } + TORCH_CHECK( + validated, "Invalid lowered kernel. 
No GroupedGridReduction found."); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn(shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + auto t0_double = t0.to(at::kDouble); + auto ref = t0_double + t0_double.sum({0}).unsqueeze(0); + + testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Group both expressions and iterations +TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {0}); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv1, tv3); + + auto tv5 = add(tv0, IrBuilder::create(2)); + auto tv6 = sum(tv5, {0}); + auto tv7 = broadcast(tv6, {true, false}); + auto tv8 = add(tv5, tv7); + + auto tv9 = add(tv4, tv8); + fusion.addOutput(tv9); + + groupReductions({tv2, tv6}); + + const int vec = 2; + const int tidx = 32; + const int tidy = 8; + + tv1->split(1, vec); + tv1->split(1, tidx); + tv1->split(0, tidy); + TransformPropagator propagator(tv1); + MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + + tv1->axis(0)->parallelize(ParallelType::BIDy); + tv1->axis(1)->parallelize(ParallelType::TIDy); + tv1->axis(2)->parallelize(ParallelType::BIDx); + tv1->axis(3)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion)); + + tv2->axis(4)->parallelize(ParallelType::Group); + + // Make sure the reduction expr is converted to GroupedGridReduciton + // and the non-reduction domains of the output TV are either + // grouped or parallelized + GpuLower gpulw(&fusion); + bool validated = false; + for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { + auto grouped_grid_reduction = + dynamic_cast(expr); + if (grouped_grid_reduction == nullptr) { + continue; + } + auto out = ir_utils::getTvOutput(grouped_grid_reduction); + for (auto out_axis : out->domain()->domain()) { + auto out_axis_pt = out_axis->getParallelType(); + TORCH_CHECK( + isParallelTypeThread(out_axis_pt) || + out_axis_pt == ParallelType::Group, + "Invalid parallel type of the reduction tensor: ", + out_axis_pt, + ". Reduction output tensor: ", + out->toString()); + } + validated = true; + } + TORCH_CHECK( + validated, "Invalid lowered kernel. 
No GroupedGridReduction found."); + + std::vector shape({99, 101}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn(shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + auto t0_double = t0.to(at::kDouble); + auto t4 = t0_double + 1 + (t0_double + 1).sum({0}).unsqueeze(0); + auto t8 = t0_double + 2 + (t0_double + 2).sum({0}).unsqueeze(0); + auto ref = t4 + t8; + + testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// ParallelType::Group with computeAt +TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = sum(tv1, {0}); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv0, tv3); + fusion.addOutput(tv4); + + const int vec = 2; + const int tidx = 32; + const int tidy = 8; + + tv2->reorder({{0, 1}}); + tv2->split(0, vec); + tv2->split(0, tidx); + tv2->split(-1, tidy); + + TransformPropagator propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + tv2->axis(2)->parallelize(ParallelType::Group); + + // This should avoid inlining the grouped domain + tv0->computeAt(tv4, -1, ComputeAtMode::MostInlined); + + TORCH_CHECK( + tv1->getComputeAtPosition() == 2, + "Invalid computeAt position: ", + tv1->toString()); + TORCH_CHECK( + tv2->getComputeAtPosition() == 2, + "Invalid computeAt position: ", + tv2->toString()); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::TIDx); + + for (auto tv : ir_utils::allTvs(&fusion)) { + tv->axis(-2)->parallelize(ParallelType::BIDy); + tv->axis(-1)->parallelize(ParallelType::TIDy); + } + + // Make sure the reduction expr is converted to GroupedGridReduciton + // and the non-reduction domains of the output TV are either + // grouped or parallelized + GpuLower gpulw(&fusion); + bool validated = false; + for (auto expr : KernelExprVisitor::getAllExprs(gpulw.kernel())) { + auto grouped_grid_reduction = + dynamic_cast(expr); + if (grouped_grid_reduction == nullptr) { + continue; + } + auto out = ir_utils::getTvOutput(grouped_grid_reduction); + for (auto out_axis : out->domain()->domain()) { + auto out_axis_pt = out_axis->getParallelType(); + TORCH_CHECK( + isParallelTypeThread(out_axis_pt) || + out_axis_pt == ParallelType::Group, + "Invalid parallel type of the reduction tensor: ", + out_axis_pt, + ". Reduction output tensor: ", + out->toString()); + } + validated = true; + } + TORCH_CHECK( + validated, "Invalid lowered kernel. 
No GroupedGridReduction found."); + + std::vector shape({99, 101}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn(shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + auto t0_double = t0.to(at::kDouble); + auto ref = t0_double + t0_double.sum({0}).unsqueeze(0); + + testValidate(fe.kernel(), outputs, {t0}, {ref}, __LINE__, __FILE__); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp index 2e3254223e7df..52217ce2bd6b8 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp @@ -167,6 +167,8 @@ TEST_F(NVFuserTest, FusionVoltaMMATT_CUDA) { // Register accumulator auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + // [M,K,N]->[M,N,K] tv0cr->reorder({{-2, -1}, {-1, -2}}); @@ -253,6 +255,8 @@ TEST_F(NVFuserTest, FusionVoltaMMATN_CUDA) { auto tv1cr = tv1cw->cacheAfter(); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); tv2c->applyMmaSwizzle( @@ -314,6 +318,8 @@ TEST_F(NVFuserTest, FusionVoltaMMANT_CUDA) { auto tv1cr = tv1cw->cacheAfter(); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + // To MNK tv0cr->reorder({{0, 2}, {1, 0}, {2, 1}}); tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); @@ -387,6 +393,8 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTT_CUDA) { auto tv1cr = tv1cw->cacheAfter(); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + // Make a CTA tile // ------------------------------------------------------------------ // [M,N] @@ -636,6 +644,8 @@ TEST_F(NVFuserTest, FusionVoltaMatMulTN_CUDA) { auto tv1cr = tv1cw->cacheAfter(); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + // Make a CTA tile // ------------------------------------------------------------------ // [M,N] @@ -784,6 +794,8 @@ TEST_F(NVFuserTest, FusionVoltaMatMulNT_CUDA) { auto tv1cr = tv1cw->cacheAfter(); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + // Make a CTA tile // ------------------------------------------------------------------ // [M,N] @@ -940,6 +952,8 @@ TEST_F(NVFuserTest, FusionAmpereMMATN_CUDA) { auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + // [M,N,K] -> [N,M,K] tv0cr->reorder({{-2, -3}, {-3, -2}}); tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); @@ -1008,6 +1022,7 @@ TEST_F(NVFuserTest, FusionAmpereMMATT_CUDA) { auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); // [M,K,N] -> [N,M,K] tv0cr->reorder({{-3, -2}, {-2, -1}, {-1, -3}}); tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); @@ -1082,6 +1097,7 @@ TEST_F(NVFuserTest, FusionAmpereMMANT_CUDA) { tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); // [K,M,N] -> [N,M,K] tv0cr->reorder({{-3, -1}, {-1, -3}}); @@ -1164,6 +1180,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTN_CUDA) { auto tv1cr = tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); auto tv2c = tv2->cacheBefore(); + 
mma_builder.accumulatorTv(tv2c); // Make a CTA tile // ------------------------------------------------------------------ @@ -1315,6 +1332,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTT_CUDA) { auto tv1cr = tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); // Make a CTA tile // ------------------------------------------------------------------ @@ -1465,6 +1483,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulNT_CUDA) { auto tv1cr = tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); // Make a CTA tile // ------------------------------------------------------------------ @@ -1655,6 +1674,7 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { // Gemm 1 accumulator reg auto tv3c = tv3->cacheBefore(); + mma_builder1.accumulatorTv(tv3c); // Gemm 2 main loop read auto tv3cw = tv3h->cacheAfter(); @@ -1665,6 +1685,7 @@ TEST_F(NVFuserTest, FusionMatmulMatmulAmpere_CUDA) { // Gemm 2 accumulator reg auto tv4c = tv4->cacheBefore(); + mma_builder2.accumulatorTv(tv4c); // General idea is inlining gemm1's main loop inside gemm2's @@ -1956,6 +1977,7 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { // Gemm 1 accumulator reg auto tv3c = tv3->cacheBefore(); + mma_builder1.accumulatorTv(tv3c); // Softmax conversion: auto tv3ccr = tv3->cacheAfter(); @@ -1971,6 +1993,7 @@ TEST_F(NVFuserTest, FusionMatmulSoftmaxMatmulAmpere_CUDA) { // Gemm 2 accumulator reg auto tv4c = tv4->cacheBefore(); + mma_builder2.accumulatorTv(tv4c); // Schedule gemm 2: // ------------------------------------------------------------------ @@ -2274,6 +2297,7 @@ TEST_F(NVFuserTest, FusionTuringMMATN_CUDA) { tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); // [M,N,K] -> [N,M,K] tv0cr->reorder({{-2, -3}, {-3, -2}}); @@ -2341,6 +2365,7 @@ TEST_F(NVFuserTest, FusionTuringMMATT_CUDA) { tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); // [M,K,N] -> [N,M,K] tv0cr->reorder({{-3, -2}, {-2, -1}, {-1, -3}}); @@ -2413,6 +2438,7 @@ TEST_F(NVFuserTest, FusionTuringMMANT_CUDA) { tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); // [K,M,N] -> [N,M,K] tv0cr->reorder({{-3, -1}, {-1, -3}}); @@ -2494,6 +2520,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulTN_CUDA) { auto tv1cr = tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); // Make a CTA tile // ------------------------------------------------------------------ @@ -2642,6 +2669,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulTT_CUDA) { auto tv1cr = tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); // Make a CTA tile // ------------------------------------------------------------------ @@ -2789,6 +2817,7 @@ TEST_F(NVFuserTest, FusionTuringMatmulNT_CUDA) { auto tv1cr = tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); // Make a CTA tile // ------------------------------------------------------------------ @@ -2938,6 +2967,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNcpAsync_CUDA) { auto 
tv1cw = tv1->cacheAfter(LoadStoreOpType::CpAsync); auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); // Make a CTA tile // ------------------------------------------------------------------ @@ -3036,6 +3066,851 @@ TEST_F(NVFuserTest, FusionAmpereMatmulTNcpAsync_CUDA) { TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } +TEST_F(NVFuserTest, FusionAmpereStridedBatchedMatmulTN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 511, N = 123, K = 88, B0 = 3, B1 = 5; + + // [B0 ,M, B1,K] + auto tv0 = makeContigTensor(4, DataType::Half); + // [B0, N, B1, K] + auto tv1 = makeContigTensor(4, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [B0,M,N,B1,K] + auto tv0b = broadcast(tv0, {false, false, true, false, false}); + auto tv1b = broadcast(tv1, {false, true, false, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. + auto tv2 = fusedMultiplySum(tv0b, tv1b, {4}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + mma_builder.configureMma(tv2); + + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); + auto tv0cw = tv0r->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1r->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + + // Group the BATCHED DIMS: + // -4 -3 -2 -1 + // [B0, M, N, B1] + tv2->reorder({{-3, -2}, {-2, -1}, {-1, -4}}); + + // -4 -3 -2 -1 + // [B0, B1, M,N] + + // Make a CTA tile + // ------------------------------------------------------------------ + // [B0, B1, M, N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 4 5 + // [B0, B1, Mo,M128, No, N128] + tv2->reorder({{-3, -2}, {-2, -3}}); + + // 0 1 2 3 4 5 + // [B0, B1, Mo, No, M128, N128] + + // Merge the outer dims: + tv2->merge(0); + tv2->merge(0); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + 
tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [... Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({B0, M, B1, K}, options); + auto t1 = at::randn({B0, N, B1, K}, options); + + FusionExecutor fe; + + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, 0, fe.compileFusion(&fusion, {t0, t1})); + + auto cg_outputs = fe.runFusion({t0, t1}); + + // ref implementation: + auto ref_t0 = t0.permute({0, 2, 1, 3}) + .contiguous() + .view({B0 * B1, M, K}); // B0, B1, M, K + auto ref_t1 = t1.permute({0, 2, 3, 1}) + .contiguous() + .view({B0 * B1, K, N}); // B0, B1, K, N + auto ref_permuted = + ref_t0.to(at::kFloat).bmm(ref_t1.to(at::kFloat)); // B0*B1, M,N + auto ref = ref_permuted.view({B0, B1, M, N}) + .permute({0, 2, 3, 1}) + .contiguous(); // B0,M,N,B1 + TORCH_CHECK(cg_outputs[0].allclose(ref, 0.0001, 0.0001)); +} + +// Matmul test on Ampere with a view on prolog +TEST_F(NVFuserTest, FusionAmpereViewMatmulTN_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 511, N = 257, K = 88; + int Ko = 11, Ki = 8; + + // [M,Ko,Ki] + auto tv0 = makeContigTensor(3, DataType::Half); + // [N,K] + auto tv1 = makeContigTensor(2, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv0_view = view(tv0, {M, Ko, Ki}, {M, K}); + + // [M,N,K] + auto tv0b = broadcast(tv0_view, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. 
+ auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Ampere_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + mma_builder.configureMma(tv2); + + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); + auto tv0cw = tv0_view->cacheAfter(); + auto tv0cr = + tv0cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::A).ldMatrix()); + auto tv1cw = tv1r->cacheAfter(); + auto tv1cr = + tv1cw->cacheAfter(mma_builder.operand(MmaOptions::Operand::B).ldMatrix()); + auto tv2c = tv2->cacheBefore(); + mma_builder.accumulatorTv(tv2c); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,Ko,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + tv0_view->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [... Mi, Ni, Ki] want [Ni, Mi, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + + // Inline the view op with the shared mem write minus + // the vectorization axes for now. 
+ tv0_view->computeAt(tv0cw, -2); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, Ko, Ki}, options); + auto t1 = at::randn({N, K}, options); + + FusionExecutor fe; + + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, 0, fe.compileFusion(&fusion, {t0, t1})); + + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = + at::native::view(t0, {M, K}).to(at::kFloat).matmul(t1.t().to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +// Initial test case for in-CTA split K with VoltaMMA +TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossWarp_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 120, N = 264, K = 120; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [N,K] + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be swizzled. + auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 16); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); + + auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + mma_builder.configureMma(tv2); + + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); + auto tv0cw = tv0b->cacheAfter(); + auto tv0cr = tv0cw->cacheAfter(); + auto tv1cw = tv1b->cacheAfter(); + auto tv1cr = tv1cw->cacheAfter(); + auto tv2c = tv2->cacheBefore(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 3); + tv1r->computeAt(tv2c, 3); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + auto tv2c_rf = tv2c->rFactor({-9, -4, -1}); + + // tv2c_rf is the actual output of the mma op after + // Rfactoring. 
+ mma_builder.accumulatorTv(tv2c_rf); + + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c_rf, -4); + tv1cr->computeAt(tv2c_rf, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,No,Ko,M,N,K] + tv0cw->reorder({ + {-3, -2}, + {-2, -3}, + }); + // [Mo,No,Ko,N,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [Mo,No,Ko,M,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [Mo,No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c_rf->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + + tv0b->computeAt(tv0cw, -2); + tv1b->computeAt(tv1cw, -2); + + tv0cr->axis(-1)->parallelize(ParallelType::Vectorize); + tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c_rf->axis(0)->parallelize(ParallelType::BIDx); + tv2c_rf->axis(1)->parallelize(ParallelType::BIDy); + tv2c_rf->axis(3)->parallelize(ParallelType::TIDz); + tv2c_rf->axis(4)->parallelize(ParallelType::TIDy); + + tv2c->axis(2)->parallelize(ParallelType::TIDz); + tv2c->axis(3)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({N, K}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +// Initial test case for cross-CTA split K with VoltaMMA +TEST_F(NVFuserTest, FusionVoltaMatMulTNCrossCTA_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + int M = 120, N = 264, K = 120; + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [N,K] + auto tv1 = makeContigTensor(2, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + // Leaving both sets of mma inputs for volta outside + // currently since they need to be 
swizzled. + auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 16, 4); + + auto mma_builder = MmaBuilder(MmaOptions::MacroType::Volta_16_16_4, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + mma_builder.configureMma(tv2); + + auto tv0r = tv0->cacheAfter(); + auto tv1r = tv1->cacheAfter(); + auto tv0cw = tv0b->cacheAfter(); + auto tv0cr = tv0cw->cacheAfter(); + auto tv1cw = tv1b->cacheAfter(); + auto tv1cr = tv1cw->cacheAfter(); + auto tv2c = tv2->cacheBefore(); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->split(-2, 2, true); + // Order K + // 0 1 2 3 4 5 6 + // [Mo,No, M128, N128, Ko, K2CTA, K32] + tv2c->reorder({{2, 4}, {3, 5}, {4, 3}, {5, 2}}); + // 0 1 2 3 4 5 6 + // [Mo,No, K2CTA, Ko M128, N128, K32] + tv0r->computeAt(tv2c, 4); + tv1r->computeAt(tv2c, 4); + + // Make warp tile: + // ------------------------------------------------------------------------- + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + auto tv2c_rf = tv2c->rFactor({-9, -6, -1}); + + // tv2c_rf is the actual output of the mma op after + // Rfactoring. + mma_builder.accumulatorTv(tv2c_rf); + + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c_rf, -4); + tv1cr->computeAt(tv2c_rf, -4); + + // Schedule gmem read and smem write: + // --------------------------------------------------------------------------- + // [Mo,No,Ko,M,N,K] + tv0cw->reorder({ + {-3, -2}, + {-2, -3}, + }); + // [Mo,No,Ko,N,M,K] + tv0cw->merge(-2); + tv0r->merge(-2); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv0r, gemm_tile, 8); + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [Mo,No,Ko,M,N,K] + tv1cw->merge(-2); + tv1r->merge(-2); + // [Mo,No,Ko,i,wy,wx,v] + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1cw, gemm_tile, 8); + scheduler_utils::matmul_utils::scheduleContiguousVectorLoad( + tv1r, gemm_tile, 8); + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + // --------------------------------------------------------------------------- + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + // --------------------------------------------------------------------------- + tv2c_rf->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + + tv0b->computeAt(tv0cw, -2); + tv1b->computeAt(tv1cw, -2); + + tv0cr->axis(-1)->parallelize(ParallelType::Vectorize); + 
tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c_rf->axis(0)->parallelize(ParallelType::BIDx); + tv2c_rf->axis(1)->parallelize(ParallelType::BIDy); + tv2c_rf->axis(2)->parallelize(ParallelType::BIDz); + tv2c_rf->axis(4)->parallelize(ParallelType::TIDz); + tv2c_rf->axis(5)->parallelize(ParallelType::TIDy); + + tv2c->axis(0)->parallelize(ParallelType::BIDx); + tv2c->axis(1)->parallelize(ParallelType::BIDy); + tv2c->axis(2)->parallelize(ParallelType::BIDz); + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({N, K}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto tref = t0.to(at::kFloat).matmul(t1.to(at::kFloat).t()); + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + +// Test an end-to-end matmul case with swizzled smem +// data layout. +TEST_F(NVFuserTest, FusionAmpereMatmulTNSwizzled_CUDA) { + NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + + int M = 257, N = 511, K = 136; + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + // [M,K] + auto tv0 = makeContigTensor(2, DataType::Half); + // [N,K] + auto tv1 = makeContigTensor(2, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M,N,K] + auto tv0b = broadcast(tv0, {false, true, false}); + auto tv1b = broadcast(tv1, {true, false, false}); + + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile) + .layout(MmaOptions::MmaInputLayout::TN); + + auto tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + + fusion.addOutput(tv2); + + mma_builder.configureMma(tv2); + + auto tv0cw = tv0->cacheAfter(LoadStoreOpType::CpAsync); + auto tv0cr = tv0cw->cacheAfter(LoadStoreOpType::LdMatrix); + auto tv1cw = tv1->cacheAfter(LoadStoreOpType::CpAsync); + auto tv1cr = tv1cw->cacheAfter(LoadStoreOpType::LdMatrix); + auto tv2c = tv2->cacheBefore(); + + mma_builder.accumulatorTv(tv2c); + + // Make a CTA tile + // ------------------------------------------------------------------ + // [M,N] + tv2->split(-2, gemm_tile.cta_tile.m); + tv2->split(-1, gemm_tile.cta_tile.n); + + // 0 1 2 3 + // [Mo,M128, No, N128] + tv2->reorder({{1, 2}, {2, 1}}); + + // 0 1 2 3 + // [Mo,No, M128, N128] + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Order K + // 0 1 2 3 4 5 + // [Mo,No, M128, N128, Ko, K32] + tv2c->split(-1, gemm_tile.cta_tile.k); + tv2c->reorder({{2, 3}, {3, 4}, {4, 2}}); + + // 0 1 2 3 4 5 + // [Mo,No, Ko M128, N128, K32] + tv0cw->computeAt(tv2c, 3); + tv1cw->computeAt(tv2c, 3); + + // Make warp tile: + // + scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(tv2c, gemm_tile); + scheduler_utils::matmul_utils::scheduleWarpTileWithNoReduction( + tv2, gemm_tile); + // -8 -7 -6 -5 -4 -3 -2 -1 + // [Mo No Ko Mwo Nwo Kwo Mw Nw Mi Ni Ki] + tv0cr->computeAt(tv2c, -4); + tv1cr->computeAt(tv2c, -4); 
+ + // Schedule gmem read and smem write: + // + // [Mo,Ko,M,K] + // Swizzle tv0: 128 x 32 tile: + tv0cw->split(-2, 8); + tv0cw->split(-2, 2); + tv0cw->split(-1, 8); + // -5 -4 -3 -2 -1 + // [Mo,Ko,Mo16,M4,M2,Ko4,K8] + tv0cw->swizzle(Swizzle2DType::XOR, -4, -2); + tv0cw->merge(-4); + tv0cw->merge(-3); + // -3 -2 -1 + // [Mo,Ko,Mo16,warp,K8] + tv0cw->split(-3, 4); + tv0cw->split(-3, 2); + // -4 -3 -2 -1 + // [Mo,Ko, S4, wz2, wy2, warp,K8] + tv0cw->axis(-4)->parallelize(ParallelType::TIDz); + tv0cw->axis(-3)->parallelize(ParallelType::TIDy); + tv0cw->axis(-2)->parallelize(ParallelType::TIDx); + tv0cw->axis(-1)->parallelize(ParallelType::Vectorize); + + tv0cw->setMemoryType(MemoryType::Shared); + // [Mo,Ko,i,wy,wx,v] + + // [No,Ko,N,K] + // Swizzle tv0: 128 x 32 tile: + tv1cw->split(-2, 8); + tv1cw->split(-2, 2); + tv1cw->split(-1, 8); + // -5 -4 -3 -2 -1 + // [No,Ko,No16,N4,N2,Ko4,K8] + tv1cw->swizzle(Swizzle2DType::XOR, -4, -2); + tv1cw->merge(-4); + tv1cw->merge(-3); + // -3 -2 -1 + // [No,Ko,No16,warp,K8] + tv1cw->split(-3, 4); + tv1cw->split(-3, 2); + // -4 -3 -2 -1 + // [No,Ko, S4, wz2, wy2, warp,K8] + tv1cw->axis(-4)->parallelize(ParallelType::TIDz); + tv1cw->axis(-3)->parallelize(ParallelType::TIDy); + tv1cw->axis(-2)->parallelize(ParallelType::TIDx); + tv1cw->axis(-1)->parallelize(ParallelType::Vectorize); + + tv1cw->setMemoryType(MemoryType::Shared); + // Schedule mma input + tv0cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + // [... Mi, Ni, Ki] + tv0b->reorder({{-2, -3}, {-3, -2}}); + tv0b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build()); + + tv1cr->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + tv1b->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build()); + + // Schedule mma output + tv2c->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + tv2->applyMmaSwizzle( + mma_builder.operand(MmaOptions::Operand::Accumulator).build()); + + // Parallelize + // 0 1 2 3 4 5 6 7 8 9 10 + // [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)] + tv2c->axis(3)->parallelize(ParallelType::TIDz); + tv2c->axis(4)->parallelize(ParallelType::TIDy); + + // Parallelize + // 0 1 2 3 4 5 6 7 + // [Mo No Mwo Nwo Mw Nw (Mi Ni)] + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::TIDz); + tv2->axis(3)->parallelize(ParallelType::TIDy); + + tv0cw->doubleBuffer(); + tv1cw->doubleBuffer(); + tv0cr->doubleBuffer(); + tv1cr->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto t0 = at::randn({M, K}, options); + auto t1 = at::randn({N, K}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto tref = t0.to(at::kFloat).matmul(t1.t().to(at::kFloat)); + + TORCH_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + #undef NVFUSER_TEST_CUDA_ARCH_GUARD } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp index 71ec9c9916a46..32070b5aa8ed5 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp @@ -13,6 +13,7 @@ void ReplayTransformations::handle(Expr* e) { switch (e->getExprType().value()) { case (ExprType::Split): case (ExprType::Merge): + case (ExprType::Swizzle2D): break; default: TORCH_INTERNAL_ASSERT( @@ -130,13 +131,63 @@ void ReplayTransformations::handle(Merge* m) { 
   id_map_[m->out()] = out;
 }
 
+void ReplayTransformations::handle(Swizzle2D* swizzle_2d) {
+  // Grab our inputs to the swizzle node
+  auto id_in_x = swizzle_2d->inX();
+  auto id_in_y = swizzle_2d->inY();
+
+  // Make sure we have corresponding entries in our map pointing to the IDs
+  // we're going to replay the swizzle on
+  auto it_x = id_map_.find(id_in_x);
+  auto it_y = id_map_.find(id_in_y);
+
+  if (it_x == id_map_.end() || it_y == id_map_.end()) {
+    if (error_on_failure_) {
+      TORCH_INTERNAL_ASSERT(
+          false, "Transform traversal failed, dependencies not met.");
+    } else {
+      return;
+    }
+  }
+
+  auto mapped_x = (*it_x).second;
+  auto mapped_y = (*it_y).second;
+
+  // Make sure these IDs are leaf IDs (meaning they have no uses we generated)
+  TORCH_INTERNAL_ASSERT(
+      leaf_ids_.find(mapped_x) != leaf_ids_.end() &&
+          leaf_ids_.find(mapped_y) != leaf_ids_.end(),
+      "Transform traversal failed, modified a node but it was not a leaf node.");
+
+  auto outs = std::make_pair(mapped_x, mapped_y);
+
+  if (replay_swizzle_) {
+    // Replay the swizzle onto the mapped IDs
+    outs = IterDomain::swizzle(swizzle_2d->swizzleType(), mapped_x, mapped_y);
+
+    // Remove the mapped IDs from the leaf IDs
+    leaf_ids_.erase(mapped_x);
+    leaf_ids_.erase(mapped_y);
+  }
+
+  // Add outputs to leaf IDs
+  leaf_ids_[outs.first] = counter++;
+  leaf_ids_[outs.second] = counter++;
+
+  // Update our ID map to include these outputs
+  id_map_[swizzle_2d->outX()] = outs.first;
+  id_map_[swizzle_2d->outY()] = outs.second;
+}
+
 ReplayTransformations::ReplayTransformations(
     const std::vector<IterDomain*>& _target_domain,
     std::unordered_map<IterDomain*, IterDomain*> _id_map,
-    bool _error_on_failure)
+    bool _error_on_failure,
+    bool replay_swizzle)
     : target_domain_(_target_domain),
       id_map_(std::move(_id_map)),
-      error_on_failure_(_error_on_failure) {
+      error_on_failure_(_error_on_failure),
+      replay_swizzle_(replay_swizzle) {
   // Make sure id_map has all the inputs needed to replay target_domain
   auto inps = IterVisitor::getInputsTo(
       std::vector<Val*>(target_domain_.begin(), target_domain_.end()));
@@ -221,10 +272,12 @@ BestEffortReplay::BestEffortReplay(
     const std::vector<IterDomain*>& target_domain,
     std::unordered_map<IterDomain*, IterDomain*> target2replay_map,
     std::unordered_map<IterDomain*, IterDomain*> replay_forward_id_map,
-    std::unordered_map<IterDomain*, IterDomain*> target_forward_id_map)
+    std::unordered_map<IterDomain*, IterDomain*> target_forward_id_map,
+    bool skip_swizzle)
     : target2replay_id_map_(std::move(target2replay_map)),
       replay_forward_id_map_(std::move(replay_forward_id_map)),
-      target_forward_id_map_(std::move(target_forward_id_map)) {
+      target_forward_id_map_(std::move(target_forward_id_map)),
+      skip_swizzle_(skip_swizzle) {
   for (auto entry : target2replay_id_map_) {
     leaf_ids_[entry.second] = counter++;
   }
@@ -277,6 +330,22 @@ BestEffortReplay::BestEffortReplay(
     }
   }
 
+  std::unordered_map<IterDomain*, Expr*> target_id2expr_map;
+  for (auto target_expr : target_exprs) {
+    for (auto id : ir_utils::filterByType<IterDomain>(target_expr->inputs())) {
+      TORCH_INTERNAL_ASSERT(
+          target_id2expr_map.insert({id, target_expr}).second,
+          "BestEffortReplay : Unexpected multi-use of id",
+          id);
+    }
+  }
+
+  if (skip_swizzle_) {
+    // Progress through all swizzle ops if we are skipping
+    // swizzles on the mapping.
+    skipSwizzles(target_id2expr_map, replay_id2expr_map);
+  }
+
   std::string err_str(
       "Error during replay, a transformation was called that conflicts with an rfactor call.");
@@ -451,6 +520,17 @@ BestEffortReplay::BestEffortReplay(
       }
     }
 
+    // Need to match swizzle type and parameters if
+    // not skipping swizzles in this mapping pass.
+ if (!skip_swizzle_ && replay_expr->etype() == ExprType::Swizzle2D) { + auto r_swizzle_2d = replay_expr->as(); + auto t_swizzle_2d = target_expr->as(); + if (!(r_swizzle_2d->swizzleType() == t_swizzle_2d->swizzleType())) { + TORCH_INTERNAL_ASSERT(!replay_has_rfactor_inp, err_str); + continue; + } + } + // Take replay expr inputs out of map: for (const auto t_i : c10::irange(target_id_inps.size())) { auto t_inp = target_id_inps[t_i]; @@ -479,6 +559,12 @@ BestEffortReplay::BestEffortReplay( leaf_ids_[r_out->as()] = counter++; } } + + if (skip_swizzle_) { + // Progress through all swizzle ops if we are skipping + // swizzles on the mapping. + skipSwizzles(target_id2expr_map, replay_id2expr_map); + } } } @@ -692,6 +778,64 @@ struct ProducerForwardingInfo { } }; +// Trace chain of swizzles until reaching +// an IterDomain that's either a leaf or +// not a producer of any swizzle. +IterDomain* getSwizzleFinalOutput( + IterDomain* id, + const std::unordered_map& id2expr) { + bool is_swizzle_input = true; + + // Note: currently not supporting swizzling consumer of another + // swizzle id, so this should terminate in 1 iter, but eventually + // will try to support stacked swizzles so keeping this pass + // generic. + while (is_swizzle_input) { + auto expr_it = id2expr.find(id); + + // This means id is a leaf that doesn't + // have any consumers. Stop iteration in this case. + if (expr_it == id2expr.end()) { + is_swizzle_input = false; + break; + } + + if (expr_it->second->etype() == ExprType::Swizzle2D) { + // In the case of 2D swizzle ops, just forward + // inX to outX and inY to outY. + auto expr = expr_it->second->as(); + if (id == expr->inX()) { + id = expr->outX(); + } else { + TORCH_INTERNAL_ASSERT( + id == expr->inY(), + "unknown input to swizzle op", + id->toString(), + expr->toString()); + id = expr->outY(); + } + } else { + // Probably unreachable but if the expression + // is unknown type assume it is not a swizzle op. + is_swizzle_input = false; + } + } + + return id; +} + +bool isSwizzleInput( + IterDomain* input_id, + const std::unordered_map& id2expr) { + auto user_expr_it = id2expr.find(input_id); + + if (user_expr_it == id2expr.end()) { + return false; + } + + return user_expr_it->second->etype() == ExprType::Swizzle2D; +} + } // namespace void BestEffortReplay::addComplimentLeafIDs( @@ -868,6 +1012,41 @@ BestEffortReplay BestEffortReplay::replayPasC( return producer_replay; } +void BestEffortReplay::skipSwizzles( + const std::unordered_map& target_id2expr, + const std::unordered_map& replay_id2expr) { + // Update target2replay map + bool updated = true; + + while (updated) { + updated = false; + for (auto it : target2replay_id_map_) { + if (isSwizzleInput(it.first, target_id2expr) || + isSwizzleInput(it.second, replay_id2expr)) { + updated = true; + auto new_target = getSwizzleFinalOutput(it.first, target_id2expr); + auto new_replay = getSwizzleFinalOutput(it.second, replay_id2expr); + + // new_target and new_replay will now be the final output + // skipping all swizzles in between. We'd need to + // update the mapping and leaf ids to the final outputs. 
+ target2replay_id_map_.erase(it.first); + TORCH_INTERNAL_ASSERT( + target2replay_id_map_.insert(std::make_pair(new_target, new_replay)) + .second, + "Unexpected replay leaf"); + // Progress the leaf ids if the replay is updated + if (it.second != new_replay && + leaf_ids_.find(it.second) != leaf_ids_.end()) { + leaf_ids_.erase(it.second); + leaf_ids_[new_replay] = counter++; + } + break; + } + } + } +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h index a0d7346fe2dba..244483933a5a6 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.h +++ b/torch/csrc/jit/codegen/cuda/transform_iter.h @@ -56,6 +56,8 @@ class TORCH_CUDA_CU_API ReplayTransformations : public IterVisitor { bool error_on_failure_ = true; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) bool ran_replay = false; // Mark if replay has been run + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + bool replay_swizzle_ = false; using IterVisitor::handle; // Transform dispatch @@ -67,11 +69,25 @@ class TORCH_CUDA_CU_API ReplayTransformations : public IterVisitor { // We're going to replay this merge operation on the corresponding IDs void handle(Merge* m) override; + // We're going to replay this swizzle operation on the corresponding IDs + // if replaying swizzle is enabled. + void handle(Swizzle2D* m) override; + public: ReplayTransformations( const std::vector& _target_domain, std::unordered_map _id_map, - bool _error_on_failure = true); + bool _error_on_failure = true, + + // Indicates if we want to replay swizzle ops on the replayed + // tensor. + // The swizzle op will be replayed if true, + // The swizzle inputs will be directly forwarded, and therefore skipping + // the swizzle op if false. + // Currently this options should always be off but + // later we may have cases in scheduling large fusions where + // this functionality could be useful. + bool replay_swizzle = false); // Replays outputs that were generated from ids.first on ids.second void runReplay(); @@ -177,6 +193,43 @@ class TORCH_CUDA_CU_API BestEffortReplay { // deterministicly size_t counter = 0; + // Determine if current replay will ignore swizzle ops. + // When not skipping swizzles, swizzle ops will have to be matched + // same way as split and merge to progress forward on the mapping. + // + // When skipping swizzles, mismatched swizzle ops will not stop matching + // further down the tensor domains but only the swizzle outputs will be on + // the target to replay map, since we only generate one-to-one maps in + // BestEffortReplay and the swizzle outputs is just picked as a convention + // for simpler and uniform mapping behavior. The swizzle op inputs will be + // added by the disjoint set passes when building the iterdomain graph. 
+ // + // Example: + // Target: + // I0o, I0i = split I0 + // Ix0o, Ix0i = swizzle I0o, I0i + // I02 = merge Ix0o, Ix0i + // Replay: + // I1o, I1i = split I1 + // I12 = merge I1o, I1i + // + // BestEffortReplay **no** skip swizzle gives: + // { + // I0->I1, + // I0o->I1o, + // I0i->I1i, + // } + // + // BestEffortReplay skip swizzle gives: + // { + // I0->I1, + // Ix0o->I1o, + // Ix0i->I1i, + // I02->I12 + // } + // + bool skip_swizzle_ = true; + bool inReplayForwardMap(IterDomain* id) const { return replay_forward_id_map_.find(id) != replay_forward_id_map_.end(); } @@ -209,13 +262,23 @@ class TORCH_CUDA_CU_API BestEffortReplay { const std::unordered_map>& compliment_map); + // Skip swizzle step to make sure both target and + // replay swizzles are skipped while the mapping + // makes progress. This makes sure that, for example + // different tensors can still be inlined despite + // different local swizzle patterns. + void skipSwizzles( + const std::unordered_map& target_id2expr, + const std::unordered_map& replay_id2expr); + public: BestEffortReplay( const std::vector& replay_domain, const std::vector& target_domain, std::unordered_map target2replay_map, std::unordered_map replay_forward_id_map = {}, - std::unordered_map target_forward_id_map = {}); + std::unordered_map target_forward_id_map = {}, + bool skip_swizzle = true); // Return iter domain map from target_domain IDs to their "replayed" // replay_domain IDs. If not in map, was not replayed. diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 2dc9212b3ffb6..f332e628a94b9 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -155,9 +155,7 @@ TensorDomain* TransformReplay::fullSelfReplay( size_t i = 0; for (auto id : self->getRootDomain()) { TORCH_INTERNAL_ASSERT( - new_self_root->getRootDomain()[i]->getParallelType() == - id->getParallelType() && - new_self_root->getRootDomain()[i]->isReduction() == + new_self_root->getRootDomain()[i]->isReduction() == id->isReduction() && new_self_root->getRootDomain()[i]->isRFactorProduct() == id->isRFactorProduct() && @@ -223,7 +221,8 @@ std::pair TransformReplay::replayPasC( const TensorView* producer, const TensorView* consumer, int consumer_compute_at_axis, - const RootDomainMap& root_map) { + const RootDomainMap& root_map, + bool replay_swizzle) { FUSER_PERF_SCOPE("TransformReplay::replayPasC"); // If this is a reduction operation, we may call transform_replay on the @@ -262,7 +261,7 @@ std::pair TransformReplay::replayPasC( // Replay producer dimensions. 
ReplayTransformations replay_PasC( - consumer_CA_ids, forwarded_replay_map, false); + consumer_CA_ids, forwarded_replay_map, false, replay_swizzle); auto leaf_ids(replay_PasC.getUnorderedLeafIDs()); @@ -631,10 +630,12 @@ std::pair TransformReplay::replayCasP( std::pair TransformReplay::replayPasC( const TensorView* producer, const TensorView* consumer, - int compute_at_axis) { + int compute_at_axis, + bool replay_swizzle) { // Use the pairwise root map as a default mapper PairwiseRootDomainMap root_map(producer, consumer); - return replayPasC(producer, consumer, compute_at_axis, root_map); + return replayPasC( + producer, consumer, compute_at_axis, root_map, replay_swizzle); } std::pair TransformReplay::replayCasP( @@ -646,18 +647,305 @@ std::pair TransformReplay::replayCasP( return replayCasP(consumer, producer, compute_at_axis, root_map); } -void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) { +// In a PasC replay, we want the producer to exactly match the consumer: +// all the beginning axes in the producer should be mapped to the consumer in +// the same order. Reductions in the producer needs to be in the back of the +// producer. +int TransformReplay::getMatchedLeafPosWithoutReplayPasC( + const TensorView* producer, + const TensorView* consumer, + int consumer_pos) { + FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayPasC"); + + const auto pairwise_map = PairwiseRootDomainMap(producer, consumer); + id_map c2p_root_map = pairwise_map.mapConsumerToProducer( + consumer->domain(), producer->domain()); + + // IterDomains in `consumer` root also in `producer` root + const auto consumer_domain = consumer->domain()->domain(); + + std::unordered_set mapped_consumer_roots; + for (auto entry : c2p_root_map) { + mapped_consumer_roots.emplace(entry.first); + } + + auto unskippable_consumer_ids_vec = DependencyCheck::getAllValsBetween( + mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); + + std::unordered_set unskippable_consumer_ids( + unskippable_consumer_ids_vec.begin(), unskippable_consumer_ids_vec.end()); + + // IterDomains in `producer` root also in `consumer` root + const auto producer_domain = producer->domain()->domain(); + + auto it_consumer = consumer_domain.begin(); + auto it_producer = producer_domain.begin(); + + id_map c2p_map = + BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) + .getReplay(); + + int mismatched_consumer_pos = 0; + int mismatched_producer_pos = 0; + while (it_consumer != consumer_domain.end()) { + if (consumer_pos == mismatched_consumer_pos) { + return mismatched_producer_pos; + } + + auto consumer_id = *it_consumer; + if (unskippable_consumer_ids.count(consumer_id) == 0) { + ++it_consumer; + ++mismatched_consumer_pos; + continue; + } + + if (it_producer == producer_domain.end()) { + return -1; + } + + auto c2p_it = c2p_map.find(consumer_id); + if (c2p_it == c2p_map.end()) { + return -1; + } + + auto producer_id = *it_producer; + if (c2p_it->second == producer_id) { + ++mismatched_consumer_pos; + ++mismatched_producer_pos; + ++it_consumer; + ++it_producer; + } else { + return -1; + } + } + if (consumer_pos == mismatched_consumer_pos) { + return mismatched_producer_pos; + } + return -1; +} + +// We want to ignore reductions in the producer in a CasP replay. 
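
For readers skimming getMatchedLeafPosWithoutReplayPasC above: it returns the producer leaf position whose outer axes already line up with the first consumer_pos axes of the consumer, or -1 when no such position exists and a real replay is unavoidable. Below is a minimal sketch of the intended call pattern, assuming the usual nvfuser headers and namespaces; the helper name and the bare setDomain call are illustrative only, and the real consumer of this query, TransformPropagator::propagateC2P further down in this patch, additionally validates compute-at and max-producer positions before adopting the replayed domain.

// Sketch only: skip the producer-as-consumer replay when it is provably redundant.
int replayOrReusePasC(TensorView* producer, const TensorView* consumer, int consumer_pos) {
  int matched_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(
      producer, consumer, consumer_pos);
  if (matched_pos >= 0) {
    // The producer's outer leaves already match up to consumer_pos; nothing to replay.
    return matched_pos;
  }
  // Otherwise fall back to a full replay and adopt the replayed domain.
  auto replay = TransformReplay::replayPasC(producer, consumer, consumer_pos);
  producer->setDomain(replay.first);
  return static_cast<int>(replay.second);
}

The CasP query defined next follows the same -1 convention while additionally skipping reduction IterDomains on the producer side.
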
+int TransformReplay::getMatchedLeafPosWithoutReplayCasP( + const TensorView* consumer, + const TensorView* producer, + int producer_pos) { + FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplayCasP"); + + const auto pairwise_map = PairwiseRootDomainMap(producer, consumer); + id_map p2c_root_map = pairwise_map.mapProducerToConsumer( + producer->domain(), consumer->domain()); + + // IterDomains in `producer` root that are not reduction + const auto producer_domain = producer->domain()->domain(); + auto unskippable_producer_ids_vec = + TensorDomain::noReductions(producer_domain); + std::unordered_set unskippable_producer_ids( + unskippable_producer_ids_vec.begin(), unskippable_producer_ids_vec.end()); + + // IterDomains in `consumer` root also in `producer` root + const auto consumer_domain = consumer->domain()->domain(); + + std::unordered_set mapped_consumer_roots; + for (auto entry : p2c_root_map) { + mapped_consumer_roots.emplace(entry.second); + } + + auto unskippable_consumer_ids_vec = DependencyCheck::getAllValsBetween( + mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()}); + + std::unordered_set unskippable_consumer_ids( + unskippable_consumer_ids_vec.begin(), unskippable_consumer_ids_vec.end()); + + auto it_producer = producer_domain.begin(); + auto it_consumer = consumer_domain.begin(); + + id_map replay_map = + BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_map) + .getReplay(); + + int mismatched_producer_pos = 0; + int mismatched_consumer_pos = 0; + while (it_producer != producer_domain.end()) { + if (producer_pos == mismatched_producer_pos) { + return mismatched_consumer_pos; + } + + auto producer_id = *it_producer; + if (unskippable_producer_ids.count(producer_id) == 0) { + ++it_producer; + ++mismatched_producer_pos; + continue; + } + + if (it_consumer == consumer_domain.end()) { + return -1; + } + + auto consumer_id = *it_consumer; + if (unskippable_consumer_ids.count(consumer_id) == 0) { + ++it_consumer; + ++mismatched_consumer_pos; + continue; + } + + auto replay_it = replay_map.find(producer_id); + if (replay_it == replay_map.end()) { + return -1; + } + + if (replay_it->second == consumer_id) { + ++mismatched_producer_pos; + ++mismatched_consumer_pos; + ++it_producer; + ++it_consumer; + } else { + return -1; + } + } + if (producer_pos == mismatched_producer_pos) { + return mismatched_consumer_pos; + } + return -1; +} + +bool TransformReplay::fullSelfMatching( + const TensorView* replay, + const TensorView* target) { + auto replay_root = replay->getRootDomain(); + auto replay_dom = replay->domain()->domain(); + auto target_root = target->getRootDomain(); + auto target_dom = target->domain()->domain(); + std::unordered_map target2replay_map; + if (replay_root.size() != target_root.size()) { + return false; + } + target2replay_map.reserve(replay_root.size()); + std::transform( + target_root.begin(), + target_root.end(), + replay_root.begin(), + std::inserter(target2replay_map, target2replay_map.begin()), + [](auto a, auto b) { return std::make_pair(a, b); }); + BestEffortReplay replay_(replay_dom, target_dom, target2replay_map); + auto r = replay_.getReplay(); + for (int64_t i = 0; i < replay_dom.size(); i++) { + auto target_id = target_dom[i]; + auto replay_it = r.find(target_id); + if (replay_it == r.end() || replay_it->second != replay_dom[i]) { + return false; + } + } + return true; +} + +namespace { + +// Make sure if tv is set to new_td it doesn't violate set compute at and max +// produce at positions. 
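
fullSelfMatching above is the cheap check that lets sibling propagation skip work: sibling outputs of a single multi-output expression (for example the three outputs of a Welford op) are expected to carry identical transformations, so a full self replay is only needed when they have drifted apart. A minimal sketch of that pattern, assuming the usual nvfuser headers; the helper name is illustrative, and TransformPropagator::propagateSibling later in this patch is the real implementation.

// Sketch only: keep a sibling output in sync with an already-scheduled one.
void syncSiblingDomains(TensorView* to, TensorView* from) {
  if (TransformReplay::fullSelfMatching(to, from)) {
    return;  // transformations already match exactly, nothing to replay
  }
  to->setDomain(TransformReplay::fullSelfReplay(to->domain(), from->domain()));
}

The anonymous-namespace helper defined next guards that setDomain step, rejecting any domain that would invalidate an established compute-at or max-producer position.
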
+bool validateDomain(TensorView* tv, TensorDomain* new_td) { + auto first_mismatch = + BestEffortReplay::findFirstMismatchedID(tv->domain(), new_td); + return first_mismatch >= (int)tv->getMaxProducerPosition() && + first_mismatch >= (int)tv->getComputeAtPosition(); +} + +} // namespace + +void TransformPropagator::propagateC2P(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); - auto replay = TransformReplay::replayPasC(to, from, pos); - to->setDomain(replay.first); - replayed_pos_[to] = replay.second; + // Note: [Using multiple TransformPropagators] + // There are cases that we use multiple TransformPropagators along different + // spanning trees with different references in the same fusion. Some of these + // spanning trees could overlap. In cases when there are overlapping nodes, + // TransformPropagator needs to respect the replay of others, because the + // current TransformPropagator might not contain the most amount of + // information on how to do the correct transformation. The logic below tells + // TransformPropagator to skip the replay when not necessary. + int new_pos = + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); + bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); + if (debug) { + std::cout << "TransformPropagator::propagateC2P" << std::endl; + std::cout << " from: " << from << " @ " << pos << std::endl; + std::cout << " to: " << to << std::endl; + } + if (new_pos < 0) { + auto replay = TransformReplay::replayPasC(to, from, pos); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay.first), + "Tried to set the domain of ", + to, + " to ", + replay.first, + " but that would invalidate previously compute at position or max producer position."); + to->setDomain(replay.first); + new_pos = replay.second; + if (debug) { + std::cout << " replayed: " << to << " @ " << new_pos << std::endl; + } + } else if (debug) { + std::cout << " replay skipped. result position: " << new_pos << std::endl; + } + replayed_pos_[to] = new_pos; } -void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) { +void TransformPropagator::propagateP2C(TensorView* from, TensorView* to) { int pos = replayed_pos_.at(from); - auto replay = TransformReplay::replayCasP(to, from, pos); - to->setDomain(replay.first); - replayed_pos_[to] = replay.second; + // See note [Using multiple TransformPropagators] + int new_pos = + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); + bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); + if (debug) { + std::cout << "TransformPropagator::propagateP2C" << std::endl; + std::cout << " from: " << from << " @ " << pos << std::endl; + std::cout << " to: " << to << std::endl; + } + if (new_pos < 0) { + auto replay = TransformReplay::replayCasP(to, from, pos); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay.first), + "Tried to set the domain of ", + to, + " to ", + replay.first, + " but that would invalidate previously compute at position or max producer position."); + to->setDomain(replay.first); + new_pos = replay.second; + if (debug) { + std::cout << " replayed: " << to << " @ " << new_pos << std::endl; + } + } else if (debug) { + std::cout << " replay skipped. 
result position: " << new_pos << std::endl; + } + replayed_pos_[to] = new_pos; +} + +void TransformPropagator::propagateSibling(TensorView* from, TensorView* to) { + int pos = replayed_pos_.at(from); + // See note [Using multiple TransformPropagators] + bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); + if (debug) { + std::cout << "TransformPropagator::propagateSibling" << std::endl; + std::cout << " from: " << from << " @ " << pos << std::endl; + std::cout << " to: " << to << std::endl; + } + if (!TransformReplay::fullSelfMatching(to, from)) { + auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay), + "Tried to set the domain of ", + to, + " to ", + replay, + " but that would invalidate previously compute at position or max producer position."); + to->setDomain(replay); + if (debug) { + std::cout << " replayed: " << to << " @ " << pos << std::endl; + } + } else if (debug) { + std::cout << " replay skipped. result position: " << pos << std::endl; + } + replayed_pos_[to] = pos; } TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) { @@ -670,6 +958,97 @@ TransformPropagator::TransformPropagator(TensorView* from, int64_t pos) { replayed_pos_[from] = pos; } +void MostInlinedTransformPropagator::propagateC2P( + TensorView* from, + TensorView* to) { + int pos = from->nDims(); + // See note [Using multiple TransformPropagators] + int new_pos = + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos); + bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); + if (debug) { + std::cout << "MostInlinedTransformPropagator::propagateC2P" << std::endl; + std::cout << " from: " << from << std::endl; + std::cout << " to: " << to << std::endl; + } + if (new_pos < 0) { + auto replay = TransformReplay::replayPasC(to, from, pos); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay.first), + "Tried to set the domain of ", + to, + " to ", + replay.first, + " but that would invalidate previously compute at position or max producer position."); + to->setDomain(replay.first); + if (debug) { + std::cout << " replayed: " << to << std::endl; + } + } else if (debug) { + std::cout << " replay skipped" << std::endl; + } +} + +void MostInlinedTransformPropagator::propagateP2C( + TensorView* from, + TensorView* to) { + int pos = from->nDims(); + // See note [Using multiple TransformPropagators] + int new_pos = + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos); + bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); + if (debug) { + std::cout << "MostInlinedTransformPropagator::propagateP2C" << std::endl; + std::cout << " from: " << from << std::endl; + std::cout << " to: " << to << std::endl; + } + if (new_pos < 0) { + auto replay = TransformReplay::replayCasP(to, from, pos); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay.first), + "Tried to set the domain of ", + to, + " to ", + replay.first, + " but that would invalidate previously compute at position or max producer position."); + to->setDomain(replay.first); + if (debug) { + std::cout << " replayed: " << to << std::endl; + } + } else if (debug) { + std::cout << " replay skipped" << std::endl; + } +} + +void MostInlinedTransformPropagator::propagateSibling( + TensorView* from, + TensorView* to) { + // See note [Using multiple TransformPropagators] + bool debug = isDebugDumpEnabled(DebugDumpOption::TransformPropagator); + if (debug) { + std::cout << 
"MostInlinedTransformPropagator::propagateSibling" + << std::endl; + std::cout << " from: " << from << std::endl; + std::cout << " to: " << to << std::endl; + } + if (!TransformReplay::fullSelfMatching(to, from)) { + auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain()); + TORCH_INTERNAL_ASSERT( + validateDomain(to, replay), + "Tried to set the domain of ", + to, + " to ", + replay, + " but that would invalidate previously compute at position or max producer position."); + to->setDomain(replay); + if (debug) { + std::cout << " replayed: " << to << std::endl; + } + } else if (debug) { + std::cout << " replay skipped" << std::endl; + } +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index d45454e149b63..3dace83adab75 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -131,12 +131,14 @@ class TORCH_CUDA_CU_API TransformReplay { static std::pair replayPasC( const TensorView* producer, const TensorView* consumer, - int consumer_compute_at_axis); + int consumer_compute_at_axis, + bool replay_swizzle = false); static std::pair replayPasC( const TensorView* producer, const TensorView* consumer, int consumer_compute_at_axis, - const RootDomainMap& root_map); + const RootDomainMap& root_map, + bool replay_swizzle = false); // Replay producer as consumer, returns {replayed_consumer_domain, // consumer_compute_at_axis}. @@ -154,18 +156,53 @@ class TORCH_CUDA_CU_API TransformReplay { static TensorDomain* fullSelfReplay( const TensorDomain* new_self_root, const TensorDomain* self); + + // Returns the leaf position in producer that matches with `consumer_pos` in + // consumer. Returns -1 if matching is impossible. This function can be used + // to test if replay is needed for getting matching outer dims. This function + // should be consistent with `replayPasC`: if you pass the tensors just + // replayed by replayPasC as inputs, you should return exactly the same + // position as `replayPasC`. However, this function is more tolerant than + // fully matching `replayPasC`: if in the consumer, there are unmappable + // dimensions, these dimensions are just ignored. + static int getMatchedLeafPosWithoutReplayPasC( + const TensorView* producer, + const TensorView* consumer, + int consumer_pos); + + // Returns the leaf position in consumer that matches with `producer_pos` in + // producer. Behavior similar to getMatchedLeafPosWithoutReplayPasC, except + // that we are also ignoring reductions in the producer. 
+ static int getMatchedLeafPosWithoutReplayCasP( + const TensorView* consumer, + const TensorView* producer, + int producer_pos); + + // tests if two tensors has fully matching transformations + static bool fullSelfMatching( + const TensorView* replay, + const TensorView* target); }; class TORCH_CUDA_CU_API TransformPropagator : public MaxRootDomainInfoSpanningTree::Propagator { + protected: std::unordered_map replayed_pos_; public: - virtual void propagateTvPasC(TensorView* from, TensorView* to) override; - virtual void propagateTvCasP(TensorView* from, TensorView* to) override; + virtual void propagateC2P(TensorView* from, TensorView* to) override; + virtual void propagateP2C(TensorView* from, TensorView* to) override; + virtual void propagateSibling(TensorView* from, TensorView* to) override; TransformPropagator(TensorView* from, int64_t pos = -1); }; +struct TORCH_CUDA_CU_API MostInlinedTransformPropagator + : public MaxRootDomainInfoSpanningTree::Propagator { + virtual void propagateC2P(TensorView* from, TensorView* to) override; + virtual void propagateP2C(TensorView* from, TensorView* to) override; + virtual void propagateSibling(TensorView* from, TensorView* to) override; +}; + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index 2359a511ceecf..5caa5b3f96716 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -21,13 +21,12 @@ DataType indexModeToDtype(KernelIndexMode index_mode) { bool isFloatingPointType(DataType dtype) { switch (dtype) { - case DataType::Bool: - return false; case DataType::Double: case DataType::Float: case DataType::Half: case DataType::BFloat16: return true; + case DataType::Bool: case DataType::Index: case DataType::Int: case DataType::Int32: @@ -78,10 +77,9 @@ bool isIntegralType(DataType dtype) { case DataType::Int32: return true; case DataType::Null: - TORCH_CHECK( - false, "Null type is not a valid argument to isFloatingPoint"); + TORCH_CHECK(false, "Null type is not a valid argument to isIntegralType"); default: - TORCH_CHECK(false, "Type not supported in isFloatingPoint"); + TORCH_CHECK(false, "Type not supported in isIntegralType"); } } @@ -348,6 +346,12 @@ static const char* expr_type2string(ExprType t) { return "GridBroadcast"; case ExprType::GridWelford: return "GridWelford"; + case ExprType::Swizzle2D: + return "Swizzle2D"; + case ExprType::Swizzle2DInt: + return "Swizzle2DInt"; + case ExprType::PairSelect: + return "PairSelect"; default: TORCH_INTERNAL_ASSERT(false, "No string found for expr type."); } @@ -685,6 +689,8 @@ static const char* parallel_type2string(ParallelType t) { return "US"; case ParallelType::Mma: return "MMA"; + case ParallelType::Group: + return "G"; case ParallelType::Serial: return "S"; default: @@ -973,6 +979,32 @@ TORCH_CUDA_CU_API std::ostream& operator<<( return out << iter_type2string(bt); } +TORCH_CUDA_CU_API std::ostream& operator<<( + std::ostream& os, + const Swizzle2DType& swizzle) { + switch (swizzle) { + case Swizzle2DType::NoSwizzle: + os << "NoSwizzle"; + break; + case Swizzle2DType::ZShape: + os << "ZShape"; + break; + case Swizzle2DType::Transpose: + os << "Transpose"; + break; + case Swizzle2DType::XOR: + os << "Xor"; + break; + case Swizzle2DType::Scatter: + os << "Scatter"; + break; + default: + TORCH_INTERNAL_ASSERT(false, "undefined 2D swizzle"); + break; + } + return os; +} + TORCH_CUDA_CU_API c10::optional inline_op_str( const UnaryOpType uotype) { const char* 
str = unary_op_type_inline_op2string(uotype); diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index d9edea262c812..f31b77bf2e06f 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -33,6 +33,7 @@ enum class ValType { NamedScalar, Predicate, TensorIndex, + IntPair }; // Manual - The user provides the Bool value. Predicate generation is bypassed. @@ -119,6 +120,9 @@ enum class ExprType { Split, ViewAsScalar, Merge, + Swizzle2D, + Swizzle2DInt, + PairSelect, Allocate, BlockSync, GridSync, @@ -255,6 +259,7 @@ enum class ParallelType { Unroll, Unswitch, Mma, + Group, Serial }; @@ -313,6 +318,14 @@ enum class LoadStoreOpType { LdMatrix, LdMatrixTranspose, CpAsync }; // a for loop is materializing. enum class DoubleBufferLoopStage { NotApplicable, Prolog, Main, Epilog }; +//! Supported swizzle types, +//! corresponds to swizzles functions on the runtime cuda +//! naming it swizzle_2d to reserve the options to have a swizzle_1d. +//! +//! TODO: unify with existing swizzle logic, currently +//! doesn't have the same type. +enum class Swizzle2DType { NoSwizzle = 0, ZShape, Transpose, XOR, Scatter }; + // Returns if function needs an f suffix on the operator when operating on a // float value i.e. sin->sinf bool needFloatSuffix(UnaryOpType t); @@ -343,6 +356,7 @@ TORCH_CUDA_CU_API std::ostream& operator<<( TORCH_CUDA_CU_API std::ostream& operator<<( std::ostream&, const DoubleBufferLoopStage); +TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const Swizzle2DType&); std::string stringifyBooleanOp(const UnaryOpType); std::string stringifyBooleanOp(const BinaryOpType); @@ -387,6 +401,10 @@ enum class LaunchConfigType { const char* const kMagicZeroName = "nvfuser_zero"; +//! Maximum number of reductions that can be grouped together. The +//! limit can be increased by extending struct Tuple define in tuple.cu. 
+static constexpr int kMaxNumGroupedReductions = 8; + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/type_inference.cpp b/torch/csrc/jit/codegen/cuda/type_inference.cpp index 58f2187ea1cc9..0bc821e024c3d 100644 --- a/torch/csrc/jit/codegen/cuda/type_inference.cpp +++ b/torch/csrc/jit/codegen/cuda/type_inference.cpp @@ -462,12 +462,20 @@ class NaiveTypePropagator { type0->withScalarType(type1->scalarType()), node); break; } - case aten::to: { + case aten::to: + case aten::_to_copy: { const auto type0 = getInputTensorType(node, 0); const auto out_dtype = toIValue(node->input(1)); - TORCH_CHECK(out_dtype, "No output type specified"); - copyScalarTypeAndDeviceToOutput( - type0->withScalarType(out_dtype->toScalarType()), node); + if (out_dtype.has_value() && out_dtype->isInt()) { + copyScalarTypeAndDeviceToOutput( + type0->withScalarType(out_dtype->toScalarType()), node); + } else { + TORCH_CHECK( + !out_dtype.has_value() || out_dtype->isNone(), + "dtype for cast unrecognized ", + out_dtype->tagKind()); + copyScalarTypeAndDeviceToOutput(type0, node); + } break; } case prim::add_optional: { diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp index 76347177baa9d..80c8c39f6ce9d 100644 --- a/torch/csrc/jit/codegen/cuda/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/utils.cpp @@ -35,7 +35,8 @@ auto parseDebugDumpOptions() { {DebugDumpOption::SchedulerDebug, false}, {DebugDumpOption::ParallelDimensions, false}, {DebugDumpOption::Halo, false}, - {DebugDumpOption::PerfDebugVerbose, false}}; + {DebugDumpOption::PerfDebugVerbose, false}, + {DebugDumpOption::TransformPropagator, false}}; if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_DUMP")) { c10::string_view options_view(dump_options); @@ -82,6 +83,8 @@ auto parseDebugDumpOptions() { options_map[DebugDumpOption::Halo] = true; } else if (token == "perf_debug_verbose") { options_map[DebugDumpOption::PerfDebugVerbose] = true; + } else if (token == "transform_propagator") { + options_map[DebugDumpOption::TransformPropagator] = true; } else { TORCH_CHECK( false, @@ -123,6 +126,8 @@ auto parseDisableOptions() { } else if (token == "fallback") { options_map[DisableOption::Fallback] = true; } else if (token == "fma") { + TORCH_WARN( + "fmad is disabled for nvrtc, which could negatively affect performance. Try removing `fma` from env variable PYTORCH_NVFUSER_DISABLE for optimal performance."); options_map[DisableOption::Fma] = true; } else if (token == "index_hoist") { options_map[DisableOption::IndexHoist] = true; diff --git a/torch/csrc/jit/codegen/cuda/utils.h b/torch/csrc/jit/codegen/cuda/utils.h index 60b0acc196a2a..7357e67edfbc4 100644 --- a/torch/csrc/jit/codegen/cuda/utils.h +++ b/torch/csrc/jit/codegen/cuda/utils.h @@ -42,8 +42,10 @@ enum class DebugDumpOption { SchedulerDebug, //! Dump scheduler heuristic parameters ParallelDimensions, //!< Dump known parallel dimensions Halo, //! Halo information of tensors - PerfDebugVerbose //! When running kernels, print verbose information - //! associated with what's running + PerfDebugVerbose, //! When running kernels, print verbose information + //! associated with what's running + TransformPropagator, //! When running TransformPropagator, print propagation + //! path and replay result }; TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);
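
To tie the pieces of this patch together, here is a rough usage sketch for the two propagator flavors, TransformPropagator and MostInlinedTransformPropagator, assuming the usual nvfuser headers and namespaces. The spanning-tree driver shown (constructing MaxRootDomainInfoSpanningTree from a reference tensor and calling traverse on it) reflects how these Propagator subclasses are normally driven, but its exact signature is not part of this excerpt, so treat the call shape as an assumption. With PYTORCH_NVFUSER_DUMP=transform_propagator set, the dump option added in utils.cpp above, each propagation step prints its path and whether the replay was skipped.

// Sketch only: propagate a reference tensor's schedule through the fusion.
void propagateReferenceSchedule(TensorView* reference_tv) {
  // Propagate the first two leaf axes of the reference to every reachable tensor.
  TransformPropagator propagator(reference_tv, 2);
  MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator);

  // Alternatively, match as many axes as possible on every tensor; replays that
  // getMatchedLeafPosWithoutReplay* proves unnecessary are skipped.
  MostInlinedTransformPropagator most_inlined;
  MaxRootDomainInfoSpanningTree(reference_tv).traverse(&most_inlined);
}
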
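
For readers unfamiliar with the Swizzle2DType options introduced in type.h above, and exercised by the tensor-core matmul test earlier in this patch that applies Swizzle2DType::XOR to its 128 x 32 shared-memory tiles, the standalone sketch below illustrates one plausible reading of the 2D index remappings. The authoritative formulas live in the new runtime file torch/csrc/jit/codegen/cuda/runtime/swizzle.cu, which this excerpt does not show, so the arithmetic here, and the omission of Scatter, is an assumption meant only to convey the intent: reshuffling (x, y) coordinates inside a tile, for example to avoid shared-memory bank conflicts.

#include <cstdio>

// Illustrative-only mappings on a size_y-wide tile (assumed, not taken from swizzle.cu):
//   NoSwizzle : identity
//   Transpose : mirror across the diagonal (square tiles)
//   XOR       : y ^= x, the classic bank-conflict-avoiding layout
//   ZShape    : odd rows walk their columns in reverse
struct Coord {
  int x;
  int y;
};

Coord xorSwizzle(Coord c, int size_y) {
  return {c.x, (c.y ^ c.x) % size_y};
}

Coord transposeSwizzle(Coord c) {
  return {c.y, c.x};
}

Coord zShapeSwizzle(Coord c, int size_y) {
  return {c.x, c.x % 2 == 0 ? c.y : size_y - 1 - c.y};
}

int main() {
  constexpr int kTile = 4;
  std::printf("XOR-swizzled column for each (row, column) of a %dx%d tile:\n", kTile, kTile);
  for (int x = 0; x < kTile; ++x) {
    for (int y = 0; y < kTile; ++y) {
      std::printf("%d ", xorSwizzle({x, y}, kTile).y);
    }
    std::printf("\n");
  }
  std::printf("ZShape-swizzled column for the same tile:\n");
  for (int x = 0; x < kTile; ++x) {
    for (int y = 0; y < kTile; ++y) {
      std::printf("%d ", zShapeSwizzle({x, y}, kTile).y);
    }
    std::printf("\n");
  }
  return 0;
}

In the XOR case every row ends up with a distinct permutation of the columns, so reading one logical column touches a different position in every row; that is the property the swizzled shared-memory tiles in the matmul test rely on.
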