Skip to content

Commit

Permalink
Use scheduler_utils to cache inputs and outputs in schedulePointwise (c…
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jul 13, 2022
1 parent 03180aa commit 3df9742
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 42 deletions.
5 changes: 2 additions & 3 deletions torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -983,9 +983,8 @@ TORCH_CUDA_CU_API void schedulePersistentKernel(
// Cache inputs if unrolled
auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll);

- // Cache and fork outputs
- std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
- scheduler_utils::cacheAndForkOutputs(fusion, unroll);
+ // Cache and fork outputs
+ auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll);

// Make sure we don't have global memory set on intermediate tensors from
// fusion segmentation
Expand Down
43 changes: 7 additions & 36 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,15 +592,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
ir_utils::getReductionOps(fusion /*, ignore_trivial=true */).empty(),
"This scheduler only handles pointwise ops.");

- // For intermediate outputs, apply cacheFork
- auto outs = fusion->outputs();
- for (const auto output : outs) {
- if (!output->uses().empty() && output->definition() != nullptr) {
- if (output->getValType().value() == ValType::TensorView) {
- output->as<TensorView>()->cacheFork();
- }
- }
- }
+ // Cache inputs
+ auto cached_inputs = scheduler_utils::cacheInputs(fusion, true);
+
+ // Cache and fork outputs
+ auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true);

std::vector<TensorView*> input_tvs;
{
Expand Down Expand Up @@ -637,31 +633,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
reference_tv != nullptr,
"Could not find a fully broadcasted output to reference schedule on.");

- // Caches of inputs
- std::vector<TensorView*> cached_inputs;
-
- // Output, cacheBefore of output
- std::vector<std::pair<TensorView*, TensorView*>> cached_outputs;
-
- // Track what should be vectorized versus unrolled
- std::unordered_set<TensorView*> vectorized_tensor;
-
- // Figure out which inputs to cache for unrolling or vectorization
- for (auto inp : input_tvs) {
- if (inp->uses().empty() || inp->isFusionOutput()) {
- continue;
- }
- cached_inputs.emplace_back(inp->cacheAfter());
- }
-
- // Figure out which outputs to cache for unrolling or vectorization
- for (auto out : output_tvs) {
- if (out->definition() == nullptr) {
- continue;
- }
- cached_outputs.emplace_back(std::make_pair(out, out->cacheBefore()));
- }
-
auto all_tvs = ir_utils::allTvs(fusion);

// Merge right side of break point
Expand Down Expand Up @@ -929,8 +900,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// Compute at cached outputs
//[BIDx, Unswitch, Vectorization, TIDx]
for (auto entry : cached_outputs) {
- auto cached_output = entry.second;
- auto output = entry.first;
+ auto cached_output = entry.first;
+ auto output = entry.second;

auto unswitch_it = std::find_if(
output->domain()->domain().begin(),
Expand Down
5 changes: 2 additions & 3 deletions torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1002,9 +1002,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) {
// Cache inputs if unrolled
auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll);

- // Cache and fork outputs
- std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
- scheduler_utils::cacheAndForkOutputs(fusion, unroll);
+ // Cache and fork outputs
+ auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll);

// Make sure we don't have global memory set on intermediate tensors from
// fusion segmentation
Expand Down

0 comments on commit 3df9742

Please sign in to comment.