Make resolved broadcast tensors non-persistent #2563

Status: Open. Wants to merge 18 commits into base: devel.
5 changes: 4 additions & 1 deletion third_party/nvfuser/csrc/scheduler/normalization.cpp
@@ -1130,7 +1130,8 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) {
std::vector<TensorView*> dummy_outputs;
if (rparams.project_persistent_buffers &&
ir_utils::getViewOps(fusion).empty()) {
dummy_outputs = reduction_scheduler_utils::projectPersistentBuffers(fusion);
dummy_outputs =
normalization_scheduler_utils::projectPersistentBuffers(fusion);
}

// Cache tensors before grabbing any references to reductions as cache_before
@@ -1213,6 +1214,8 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) {
cached_outputs,
dummy_outputs);

normalization_scheduler_utils::fixUpInvalidPersistentBuffers(fusion);

if (rparams.compute_persistent_buffer_with_first_consumer) {
TORCH_INTERNAL_ASSERT(
rparams.persistent_kernel,
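For context on how the dummy outputs returned by projectPersistentBuffers are consumed later, here is a minimal sketch of the typical pattern (not the exact code of this scheduler; Fusion::addOutput/removeOutput are existing APIs, the surrounding structure is an assumption):

// Sketch: keep the recomputation expressions alive during transform
// propagation by temporarily registering the dummy tensors as outputs.
for (auto dummy : dummy_outputs) {
  fusion->addOutput(dummy);
}

// ... reference scheduling, transform propagation, parallelization ...

// The dummy tensors are not real outputs, so strip them again afterwards.
for (auto dummy : dummy_outputs) {
  fusion->removeOutput(dummy);
}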
199 changes: 199 additions & 0 deletions third_party/nvfuser/csrc/scheduler/normalization_utils.cpp
@@ -1,9 +1,17 @@
#include <fusion.h>
#include <inlining.h>
#include <ir_cloner.h>
#include <lower_trivial_broadcast.h>
#include <ops/arith.h>
#include <scheduler/debug_utils.h>
#include <scheduler/normalization_utils.h>
#include <scheduler/utils.h>
#include <utils.h>

#include <ATen/cuda/CUDAContext.h>

#include <memory>

namespace nvfuser {
namespace normalization_scheduler_utils {

@@ -494,5 +502,196 @@ std::optional<GridOuterNormalizationParams> getGridOuterNormalizationParams(
return std::nullopt;
}

namespace {

// Go through the resolution points one by one. Resolution points are points
// where the reduction branch meets the residual branch. These are points
// where the persistent buffer may no longer be needed (one point could be
// after another, and the buffer would be needed until the last resolution
// point).
std::vector<TensorView*> getPersistentUseOfBuffer(
Review comment (Collaborator, PR author): Extracted from projectPersistentBuffers with no change.

TensorView* buffer,
const std::vector<TensorView*>& resolution_points) {
std::vector<TensorView*> persistent_use_of_buffer;

for (auto resolution_point : resolution_points) {
// Need to go through all paths from the persistent buffer to the
// resolution point
auto chains_to_resolution =
DependencyCheck::getAllDependencyChains(buffer, resolution_point);
for (auto chain : chains_to_resolution) {
auto tv_chain = ir_utils::filterByType<TensorView>(chain);

// To move the persistent buffers to the inputs, we need to recompute
// the persistent buffer for all branches that don't go through a
// reduction. If there's a reduction on the current path between the
// persistent buffer and resolution, continue, there's no need to
// replicate this use.
if (std::any_of(tv_chain.begin(), tv_chain.end(), [](TensorView* tv) {
return tv->hasReduction();
})) {
continue;
}

// Grab use of the buffer, chain[0] is the persistent buffer, chain[1]
// is its first use.
auto use = chain[1];

// Only grab unique uses, a persistent buffer could be used multiple
// times in the same expression.
if (std::find(
persistent_use_of_buffer.begin(),
persistent_use_of_buffer.end(),
use) != persistent_use_of_buffer.end()) {
continue;
}
persistent_use_of_buffer.emplace_back(use->as<TensorView>());
}
}

return persistent_use_of_buffer;
}

} // namespace

std::vector<TensorView*> projectPersistentBuffers(Fusion* fusion) {
Review comment (Collaborator, PR author): Moved from reduction_utils.

auto persistent_info = scheduler_utils::persistentBuffers(fusion);
std::vector<TensorView*> dummy_outputs;

// Convenience accessors
const auto& persistent_buffers = persistent_info.persistent_buffers;
const auto& persistent_resolution_points =
persistent_info.persistent_buffer_resolution_points;
const auto& projected_buffers =
persistent_info.projectable_persistent_buffers;

TORCH_INTERNAL_ASSERT(
persistent_buffers.size() == persistent_resolution_points.size());

// Iterate through projected buffers, tracking the index each corresponds to,
// since there's a resolution point entry for every buffer.
for (auto buffer_i : c10::irange(persistent_buffers.size())) {
auto buffer = persistent_buffers[buffer_i];
if (std::find(projected_buffers.begin(), projected_buffers.end(), buffer) ==
projected_buffers.end()) {
continue;
}

const auto& resolution_points = persistent_resolution_points.at(buffer_i);

const auto persistent_use_of_buffer =
getPersistentUseOfBuffer(buffer, resolution_points);

// For all uses that do not go towards the reduction operations in the
// persistent section of the graph, recompute the persistent buffer.
for (auto use : persistent_use_of_buffer) {
TORCH_INTERNAL_ASSERT(use->definition() != nullptr);
auto buffer_replicate = RecomputeTv::recompute(buffer);
// Create a shortcut buffer <--> buffer_replicate for propagation.
// Why is this needed?
// Consider that we have a fusion
//
// T0[I]
// T1[b b I] = broadcast(T0)
// T2[b b r] = reduction(T1)
// T3[b b b] = broadcast(T2)
// T4[b, b, I] = T1 + T3
// T5[b, b, r] = reduction(T4)
//
// After projection, it becomes
//
// T0[I]
// T1[b b I] = broadcast(T0)
// T2[b b r] = reduction(T1)
// T3[b b b] = broadcast(T2)
// T6[b b I] = broadcast(T0)
// T4[b, b, I] = T6 + T3
// T5[b, b, r] = reduction(T4)
//
// During schedule, we need to propagate from T2 to T5. However, in the
// resulting DAG, neither the propagation path T2->T3->T4->T5 nor
// T2->T1->T0->T6->T4->T5 works because they both have missing root
// domain. But adding `T7 = T1 + T6` creates a new propagation path
// `T2->T1->T7->T6->T4->T5` which has all root domain information.
// See FusionBroadcastPersistentReduction_CUDA for an example
dummy_outputs.emplace_back(add(buffer_replicate, buffer));
ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
}
}
return dummy_outputs;
}
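As a reading aid, the pre-projection fusion described in the comment above (T0 through T5) could be built roughly as follows. This is an illustrative sketch only: makeSymbolicTensor is assumed to be the usual test helper, and broadcast/sum/add are the standard nvfuser ops.

Fusion fusion;
FusionGuard fg(&fusion);

auto t0 = makeSymbolicTensor(1); // T0[I]
fusion.addInput(t0);

auto t1 = broadcast(t0, {true, true, false});  // T1[b, b, I]
auto t2 = sum(t1, {2});                        // T2[b, b, r]
auto t3 = broadcast(t2, {false, false, true}); // T3[b, b, b]
auto t4 = add(t1, t3);                         // T4[b, b, I]; resolves T1
auto t5 = sum(t4, {2});                        // T5[b, b, r]
fusion.addOutput(t5);

// projectPersistentBuffers(&fusion) recomputes T1 from T0 for the use in T4
// (T6 in the comment) and returns the dummy output T7 = T6 + T1.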

std::unordered_set<TensorView*> fixUpInvalidPersistentBuffers(Fusion* fusion) {
auto persistent_info = scheduler_utils::persistentBuffers(fusion);
const auto& persistent_buffers = persistent_info.persistent_buffers;
const auto& persistent_resolution_points =
persistent_info.persistent_buffer_resolution_points;

std::unique_ptr<ConcretizedBroadcastDomains> concretize_info;

std::unordered_set<TensorView*> recomputed_tvs;

bool recompute_done = false;

for (auto buffer_i : c10::irange(persistent_buffers.size())) {
auto buffer = persistent_buffers.at(buffer_i);

// Check if this buffer needs to be recomputed
bool need_recomputation = false;

for (const auto i : c10::irange(
buffer->getComputeAtPosition(),
buffer->domain()->domain().size())) {
Review comment (Owner): Couldn't the axes be parallelized consistently even though they are outside the inline point? I guess the tendency is that if we parallelize consistently we put that parallel dim within the compute-at point, but this check seems to assume that?

Review comment (Collaborator, PR author): Ah, you're right. What I thought was that if a broadcast domain appears to the left of a computeAt position, it is effectively expanded to the mapped consumer domain, so I thought that would make the producer indexed consistently. However, the mapped consumer domain itself may also be a broadcast domain, which effectively means we need something like the RAW sync logic... Will think about it.

Review comment (Owner): Yeah, or I need to finish my damn indexing PR, which would have this information 😭
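For readers following the thread: the condition under discussion boils down to the per-axis predicate below. This is a restated sketch of the check added in this function, not a separate API, and it inherits the open question of whether inspecting only axes outside the computeAt position is sufficient.

// Restated sketch of the per-axis check: a concretized broadcast axis
// parallelized like this implies reading data that only a subset of
// threads (or blocks) actually computed and stored.
bool axisMayCarryUnresolvedDependency(TensorView* buffer, IterDomain* axis) {
  const auto mem = buffer->getMemoryType();
  return (axis->isThreadDim() && mem == MemoryType::Local) ||
      (axis->isBlockDim() &&
       (mem == MemoryType::Local || mem == MemoryType::Shared));
}
// The actual loop below additionally requires concretize_info->isConcretized(axis).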

auto axis = buffer->axis(i);

// Unresolved data dependency could exist if:
// - Parallelized by tidx and stored on Local
// - Parallelized by bidx and stored on Local or Shared
if ((axis->isThreadDim() &&
buffer->getMemoryType() == MemoryType::Local) ||
(axis->isBlockDim() &&
(buffer->getMemoryType() == MemoryType::Local ||
buffer->getMemoryType() == MemoryType::Shared))) {
if (!concretize_info) {
concretize_info =
std::make_unique<ConcretizedBroadcastDomains>(fusion);
}

// concretize_info->isConcretized(axis) means the axis is a
// concretized broadcast domain.
if (!concretize_info->isConcretized(axis)) {
continue;
}

need_recomputation = true;
break;
}
}

if (!need_recomputation) {
continue;
}

const auto& resolution_points = persistent_resolution_points.at(buffer_i);

const auto persistent_use_of_buffer =
getPersistentUseOfBuffer(buffer, resolution_points);

for (auto use : persistent_use_of_buffer) {
Review comment (Owner): So this is just adding a new expression for the first use of the persistent buffer?

Review comment (Collaborator, PR author): This just follows the same logic as the buffer projection. It replicates all the dependent expressions from the inputs up to the first use of the persistent buffer.

Review comment (Owner): That's what I thought; I'm just forgetting how, as it looks like it's only recomputing one use.
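To make the exchange above concrete: the loop below creates one replica per qualifying use, so a buffer with several non-reduction uses is recomputed once for each of them. A comment-only illustration with hypothetical tensor names:

// Before: tv_buf is persistent and feeds two expressions after the reduction.
//   tv_a = add(tv_buf, tv_x);
//   tv_b = mul(tv_buf, tv_y);
//
// After the loop below, each use reads its own replica recomputed from the
// fusion inputs, so tv_buf no longer needs to live across the reduction:
//   tv_buf_r1 = RecomputeTv::recompute(tv_buf); // rewired into tv_a's definition
//   tv_buf_r2 = RecomputeTv::recompute(tv_buf); // rewired into tv_b's definition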

TORCH_INTERNAL_ASSERT(use->definition() != nullptr);
auto buffer_replicate = RecomputeTv::recompute(buffer);
ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
recomputed_tvs.insert(buffer);
recompute_done = true;
}
}

if (recompute_done) {
inlineMost();
}

return recomputed_tvs;
}

} // namespace normalization_scheduler_utils
} // namespace nvfuser
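One pattern this new pass targets, per the PR title, is a broadcast that gets resolved (concretized) while also being live across a reduction. Below is an illustrative sketch, not the PR's actual test case; makeSymbolicTensor is assumed to be the usual test helper.

Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(1); // [I1]
auto tv1 = makeSymbolicTensor(2); // [I0, I1]
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = broadcast(tv0, {true, false});  // [b, I1]; becomes a persistent buffer
auto tv3 = add(tv2, tv1);                  // concretizes tv2's broadcast domain to I0
auto tv4 = sum(tv3, {1});                  // reduction over I1
auto tv5 = broadcast(tv4, {false, true});  // [I0, b]
auto tv6 = add(tv2, tv5);                  // tv2 used again after the reduction
fusion.addOutput(tv6);

// If the scheduler binds tv2's concretized broadcast axis to TIDx or BIDx while
// tv2 stays in Local memory, only a subset of threads/blocks would hold the
// stored values, so fixUpInvalidPersistentBuffers recomputes tv2 for the tv6 use.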
22 changes: 22 additions & 0 deletions third_party/nvfuser/csrc/scheduler/normalization_utils.h
@@ -8,6 +8,9 @@
#include <vector>

namespace nvfuser {

class Fusion;

namespace normalization_scheduler_utils {

//! Utility class to iterate candidates of launch configurations in a
@@ -145,5 +148,24 @@ std::optional<GridOuterNormalizationParams> getGridOuterNormalizationParams(
int64_t vectorize_factor,
int64_t persistent_buffer_size);

// Take all projectable persistent buffers and move them to the inputs. This
// function creates dummy outputs, which should be used in later stages of
// scheduling.
TORCH_CUDA_CU_API std::vector<TensorView*> projectPersistentBuffers(
Fusion* fusion);

//! Persistent buffers are effectively tensors that cannot be inlined
//! due to the reduction and broadcast pattern. Since we store
//! persistent buffers in registers, they have to be parallelized in
//! such a way that no data dependency exists between threads. Buffers
//! that do have such dependencies cannot be persistent. This function
//! detects such buffers and makes them non-persistent by
//! recomputation. Returns the buffers that are made non-persistent.
//!
//! Alternatively, such buffers could be stored in shared memory if
//! the dependency only exists between threads in the same thread
//! block. This is not considered yet.
std::unordered_set<TensorView*> fixUpInvalidPersistentBuffers(Fusion* fusion);

} // namespace normalization_scheduler_utils
} // namespace nvfuser
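A rough sketch of how the two entry points above are meant to be used together; it mirrors the call sites added to schedulePersistentKernel in this PR rather than introducing new API:

// 1) Optionally project persistent buffers to the inputs, collecting dummy outputs.
std::vector<TensorView*> dummy_outputs;
if (rparams.project_persistent_buffers && ir_utils::getViewOps(fusion).empty()) {
  dummy_outputs = normalization_scheduler_utils::projectPersistentBuffers(fusion);
}

// 2) Schedule the reduction, propagate transforms, parallelize, inline, ...

// 3) Demote persistent buffers whose concretized broadcast axes would carry
//    cross-thread or cross-block data dependencies.
normalization_scheduler_utils::fixUpInvalidPersistentBuffers(fusion);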
106 changes: 0 additions & 106 deletions third_party/nvfuser/csrc/scheduler/reduction_utils.cpp
@@ -656,111 +656,5 @@ TensorView* sortAndRFactor(TensorView* reference_tv) {
return ir_utils::rfactorHelper(reference_tv, rfactor_axes);
}

std::vector<TensorView*> projectPersistentBuffers(Fusion* fusion) {
Review comment (Collaborator, PR author): Moved.

auto persistent_info = scheduler_utils::persistentBuffers(fusion);
std::vector<TensorView*> dummy_outputs;

// Convenience accessors
const auto& persistent_buffers = persistent_info.persistent_buffers;
const auto& persistent_resolution_points =
persistent_info.persistent_buffer_resolution_points;
const auto& projected_buffers =
persistent_info.projectable_persistent_buffers;

TORCH_INTERNAL_ASSERT(
persistent_buffers.size() == persistent_resolution_points.size());

// Iterate through projected buffers, tracking which index it corresponds too
// since there's a resolution point entry for every buffer.
for (auto buffer_i : c10::irange(persistent_buffers.size())) {
auto buffer = persistent_buffers[buffer_i];
if (std::find(projected_buffers.begin(), projected_buffers.end(), buffer) ==
projected_buffers.end()) {
continue;
}

auto resolution_points = persistent_resolution_points[buffer_i];

std::vector<Val*> persistent_use_of_buffer;

// Go through the resolution points one by one. Resolution points are points
// in which the reduction branch meets the residual branch. These are points
// where the persitent buffer may no longer be needed (one point could be
// after another, and the buffer would be needed until the last resolution
// points)
for (auto resolution_point : resolution_points) {
// Need to go through all paths from the persistent buffer to the
// resolution point
auto chains_to_resolution =
DependencyCheck::getAllDependencyChains(buffer, resolution_point);
for (auto chain : chains_to_resolution) {
auto tv_chain = ir_utils::filterByType<TensorView>(chain);

// To move the persistent buffers to the inputs, we need to recompute
// the persistent buffer for all branches that don't go through a
// reduction. If there's a reduction on the current path between the
// persistent buffer and resolution, continue, there's no need to
// replicate this use.
if (std::any_of(tv_chain.begin(), tv_chain.end(), [](TensorView* tv) {
return tv->hasReduction();
})) {
continue;
}

// Grab use of the buffer, chain[0] is the persistent buffer, chain[1]
// is its first use.
auto use = chain[1];

// Only grab unique uses, a persistent buffer could be used multiple
// times in the same expression.
if (std::find(
persistent_use_of_buffer.begin(),
persistent_use_of_buffer.end(),
use) != persistent_use_of_buffer.end()) {
continue;
}
persistent_use_of_buffer.emplace_back(use);
}
}

// For all uses that do not go towards the reduction operations in the
// persistent section of the graph, recompute the persistent buffer.
for (auto use : persistent_use_of_buffer) {
TORCH_INTERNAL_ASSERT(use->definition() != nullptr);
auto buffer_replicate = RecomputeTv::recompute(buffer);
// Create a shortcut buffer <--> buffer_replicate for propagation.
// Why is this needed?
// Consider that we have a fusion
//
// T0[I]
// T1[b b I] = broadcast(T0)
// T2[b b r] = reduction(T1)
// T3[b b b] = broadcast(T2)
// T4[b, b, I] = T1 + T3
// T5[b, b, r] = reduction(T4)
//
// After projection, it becomes
//
// T0[I]
// T1[b b I] = broadcast(T0)
// T2[b b r] = reduction(T1)
// T3[b b b] = broadcast(T2)
// T6[b b I] = broadcast(T0)
// T4[b, b, I] = T6 + T3
// T5[b, b, r] = reduction(T4)
//
// During schedule, we need to propagate from T2 to T5. However, in the
// resulting DAG, neither the propagation path T2->T3->T4->T5 nor
// T2->T1->T0->T6->T4->T5 works because they both have missing root
// domain. But adding `T7 = T1 + T6` creates a new propagation path
// `T2->T1->T7->T6->T4->T5` which has all root domain information.
// See FusionBroadcastPersistentReduction_CUDA for an example
dummy_outputs.emplace_back(add(buffer_replicate, buffer));
ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
}
}
return dummy_outputs;
}

} // namespace reduction_scheduler_utils
} // namespace nvfuser
6 changes: 0 additions & 6 deletions third_party/nvfuser/csrc/scheduler/reduction_utils.h
@@ -41,11 +41,5 @@ TORCH_CUDA_CU_API void multiReductionInliner(
// Reduction inliner expects an rfactored domain.
TORCH_CUDA_CU_API TensorView* sortAndRFactor(TensorView* reference_tv);

// Take all projectable persistent buffers, and move them to the inputs. This
// function create dummy outputs which should be used in later stages of the
// scheduling.
TORCH_CUDA_CU_API std::vector<TensorView*> projectPersistentBuffers(
Review comment (Collaborator, PR author): Moved.

Fusion* fusion);

} // namespace reduction_scheduler_utils
} // namespace nvfuser