Enable output allocation cache (#2010)
Fixes #2002

Checks all IterDomains on outputs and disables the output allocation cache when any extent value is a consumer of fusion inputs.
jjsjann123 authored Sep 30, 2022
1 parent 35440b7 commit a4effa6
Showing 4 changed files with 41 additions and 2 deletions.
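
The fix relies on the fact that the output allocation cache is keyed only on input tensor shapes. The following standalone sketch (not nvfuser code; run and output_cache are hypothetical stand-ins) illustrates the hazard the commit guards against: when an output extent comes from a runtime scalar input, a shape-only cache can hand back a stale, wrongly sized buffer.

// Minimal sketch of the failure mode, assuming a cache keyed only on input
// tensor shapes. Not nvfuser code.
#include <cassert>
#include <map>
#include <vector>

// Hypothetical "fusion": the output length equals the scalar input `n`,
// independent of the input tensor's shape (a tensor-factory-like pattern).
std::vector<float> run(const std::vector<float>& x, int n) {
  return std::vector<float>(n, x.empty() ? 0.0f : x[0]);
}

int main() {
  // Cache keyed only on the input tensor's size, mimicking a caching id that
  // looks at tensor shapes but ignores scalar inputs.
  std::map<size_t, std::vector<float>> output_cache;

  std::vector<float> x(4, 1.0f);

  // First call: allocate and cache an output of length 3.
  output_cache.emplace(x.size(), run(x, 3));

  // Second call: same tensor shape, different scalar. A shape-only cache hit
  // would reuse the length-3 buffer even though length 5 is required, which
  // is why the commit disables output allocation caching for such fusions.
  auto it = output_cache.find(x.size());
  assert(it != output_cache.end());
  assert(it->second.size() == 3);  // stale size; correct size for n=5 is 5
}
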
29 changes: 29 additions & 0 deletions torch/csrc/jit/codegen/cuda/executor.cpp
@@ -183,6 +183,35 @@ void FusionExecutor::compileFusion(
     TORCH_INTERNAL_ASSERT(
         out->getValType() == ValType::TensorView,
         "Output types from fusions that are not tensors are not supported at this point.");

+    const auto maybe_rfactor_domain =
+        out->as<TensorView>()->getMaybeRFactorDomain();
+    // Walk through outputs to see if output shapes depend on non-tensor
+    // inputs. In that case, output allocation caching must be disabled,
+    // since the caching id only looks at tensor shapes.
+    // See issue https://github.com/csarofeen/pytorch/issues/2002
+    std::vector<Val*> output_extents;
+    for (const auto id : maybe_rfactor_domain) {
+      Val* extent = nullptr;
+      if (id->isReduction() || id->isStride()) {
+        continue;
+      } else if (id->isBroadcast() && id->hasExpandedExtent()) {
+        extent = id->expandedExtent();
+      } else {
+        extent = id->extent();
+      }
+      output_extents.emplace_back(extent);
+    }
+    auto dependencies = InputsOf::outputs(fusion, output_extents);
+    if (std::any_of(dependencies.begin(), dependencies.end(), [](Val* val) {
+          return val->isFusionInput();
+        })) {
+      // TODO: the parameter cache is too big a hammer here. We should consider
+      // separating the caching logic of output sizes & launch params, since an
+      // output size dependency should only invalidate the output sizes.
+      disable_parameter_cache_ = true;
+      break;
+    }
+  }

   if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) {
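
As a distillation of the hunk above, here is a toy, self-contained model (not nvfuser IR; Val, collectInputs, and the flags are made up) of the detection logic: gather the extent of each output domain, walk back to its leaf producers, and disable caching if any of them is a fusion input.

// Toy model of the detection logic; not nvfuser code.
#include <algorithm>
#include <cassert>
#include <vector>

struct Val {
  bool is_fusion_input = false;
  std::vector<Val*> producers;  // values this one is computed from
};

// Stand-in for a dependency walk like InputsOf::outputs(): collect all leaf
// producers reachable from a value.
void collectInputs(Val* v, std::vector<Val*>& out) {
  if (v->producers.empty()) {
    out.push_back(v);
    return;
  }
  for (Val* p : v->producers) {
    collectInputs(p, out);
  }
}

int main() {
  Val scalar_input{true, {}};                  // runtime scalar fusion input
  Val derived_extent{false, {&scalar_input}};  // output extent computed from it

  std::vector<Val*> output_extents{&derived_extent};
  std::vector<Val*> dependencies;
  for (Val* e : output_extents) {
    collectInputs(e, dependencies);
  }

  bool disable_cache =
      std::any_of(dependencies.begin(), dependencies.end(), [](Val* v) {
        return v->is_fusion_input;
      });
  assert(disable_cache);  // extent depends on a fusion input -> no caching
}
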
5 changes: 4 additions & 1 deletion torch/csrc/jit/codegen/cuda/executor.h
@@ -310,7 +310,10 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
   // Profiling support: the last launch param used
   LaunchParams launch_params_;

-  // Profiling support: knob to disable caching of launch params
+  // Profiling support: disable caching of launch params and output allocation.
+  // Output allocation is also disabled when output sizes depend on runtime
+  // scalar inputs, such as in the case of tensor factories. See
+  // https://github.com/csarofeen/pytorch/issues/2002
   bool disable_parameter_cache_ = false;

   // Profiling support: kept copy of the cuda kernel
4 changes: 3 additions & 1 deletion torch/csrc/jit/codegen/cuda/executor_kernel_arg.h
@@ -301,7 +301,9 @@ class TORCH_CUDA_CU_API KernelArgumentHolder {
       : index_mode_(index_mode) {}

   KernelArgumentHolder(const KernelArgumentHolder& self)
-      : device_index_(self.getDeviceIndex()), index_mode_(self.getIndexMode()) {
+      : device_index_(self.getDeviceIndex()),
+        cache_id_(self.getCacheId()),
+        index_mode_(self.getIndexMode()) {
     for (const auto& arg : self.arguments_) {
       push(arg.get());
     }
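
The one-line change above matters because a hand-written copy constructor silently drops any member it does not mention. A standalone illustration (not nvfuser code; Holder is hypothetical):

// Illustration of the bug class fixed by copying cache_id_; not nvfuser code.
#include <cassert>
#include <optional>

struct Holder {
  std::optional<size_t> cache_id_;
  int index_mode_ = 0;

  Holder() = default;
  // Before the fix: cache_id_ is not mentioned, so the copy loses it.
  Holder(const Holder& self) : index_mode_(self.index_mode_) {}
};

int main() {
  Holder a;
  a.cache_id_ = 42;
  Holder b(a);                       // copy drops the cache id
  assert(!b.cache_id_.has_value());  // allocation cache can no longer be keyed
}
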
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_cache.cpp
@@ -621,11 +621,16 @@ std::vector<at::Tensor> FusionKernelRuntime::runWithInput(
                 << std::endl;
   }

+  // Groups should share the cache id.
+  auto group_cache_id = args.getCacheId();
   for (auto group_to_run : runtime_workspace_.group_run_order) {
     // TODO: index mode should be updated per segmented kernel
     // Prepare input vector
     KernelArgumentHolder group_runtime_inputs(args.getIndexMode());
     group_runtime_inputs.setDeviceIndex(args.getDeviceIndex());
+    if (group_cache_id.has_value()) {
+      group_runtime_inputs.setCacheId(group_cache_id.value());
+    }
     for (auto input : group_to_run->inputs()) {
       group_runtime_inputs.push(tensor_map.at(input));
     }
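
The hunk above forwards the parent holder's cache id to each segmented group's argument holder, so every segment keys its output allocation on the same entry. A minimal sketch of that pattern (not nvfuser code; ArgHolder is hypothetical):

// Sketch of forwarding an optional cache id to per-group holders.
#include <cassert>
#include <optional>
#include <vector>

struct ArgHolder {
  std::optional<size_t> cache_id;
  void setCacheId(size_t id) { cache_id = id; }
};

int main() {
  ArgHolder parent;
  parent.cache_id = 7;  // id assigned by the caching layer, if any

  std::vector<ArgHolder> groups(3);  // one holder per segmented group
  for (auto& g : groups) {
    if (parent.cache_id.has_value()) {
      g.setCacheId(parent.cache_id.value());  // forward only when present
    }
  }
  assert(groups.front().cache_id == parent.cache_id);
}
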
