diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp
index 6f3b736579d93..656db1c0ed805 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp
@@ -983,9 +983,8 @@ TORCH_CUDA_CU_API void schedulePersistentKernel(
   // Cache inputs if unrolled
   auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll);
 
-  // Cache and fork outputs
-  std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
-      scheduler_utils::cacheAndForkOutputs(fusion, unroll);
+  // Cache and fork outputs
+  auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll);
 
   // Make sure we don't have global memory set on intermediate tensors from
   // fusion segmentation
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
index 32943be72d440..2f00863776d7f 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -592,15 +592,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       ir_utils::getReductionOps(fusion /*, ignore_trivial=true */).empty(),
       "This scheduler only handles pointwise ops.");
 
-  // For intermediate outputs, apply cacheFork
-  auto outs = fusion->outputs();
-  for (const auto output : outs) {
-    if (!output->uses().empty() && output->definition() != nullptr) {
-      if (output->getValType().value() == ValType::TensorView) {
-        output->as<TensorView>()->cacheFork();
-      }
-    }
-  }
+  // Cache inputs
+  auto cached_inputs = scheduler_utils::cacheInputs(fusion, true);
+
+  // Cache and fork outputs
+  auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true);
 
   std::vector<TensorView*> input_tvs;
   {
@@ -637,31 +633,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
       reference_tv != nullptr,
       "Could not find a fully broadcasted output to reference schedule on.");
 
-  // Caches of inputs
-  std::vector<TensorView*> cached_inputs;
-
-  // Output, cacheBefore of output
-  std::vector<std::pair<TensorView*, TensorView*>> cached_outputs;
-
-  // Track what should be vectorized versus unrolled
-  std::unordered_set<TensorView*> vectorized_tensor;
-
-  // Figure out which inputs to cache for unrolling or vectorization
-  for (auto inp : input_tvs) {
-    if (inp->uses().empty() || inp->isFusionOutput()) {
-      continue;
-    }
-    cached_inputs.emplace_back(inp->cacheAfter());
-  }
-
-  // Figure out which outputs to cache for unrolling or vectorization
-  for (auto out : output_tvs) {
-    if (out->definition() == nullptr) {
-      continue;
-    }
-    cached_outputs.emplace_back(std::make_pair(out, out->cacheBefore()));
-  }
-
   auto all_tvs = ir_utils::allTvs(fusion);
 
   // Merge right side of break point
@@ -929,8 +900,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
   // Compute at cached outputs
   //[BIDx, Unswitch, Vectorization, TIDx]
   for (auto entry : cached_outputs) {
-    auto cached_output = entry.second;
-    auto output = entry.first;
+    auto cached_output = entry.first;
+    auto output = entry.second;
 
     auto unswitch_it = std::find_if(
         output->domain()->domain().begin(),
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp
index 1696242f4ff28..84a78bcf927d1 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp
@@ -1002,9 +1002,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) {
   // Cache inputs if unrolled
   auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll);
 
-  // Cache and fork outputs
-  std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
-      scheduler_utils::cacheAndForkOutputs(fusion, unroll);
+  // Cache and fork outputs
+  auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll);
 
   // Make sure we don't have global memory set on intermediate tensors from
   // fusion segmentation