-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make resolved broadcast tensors non-persistent #2563
base: devel
Are you sure you want to change the base?
Changes from all commits
9465b73
dc4b796
2e60d1b
94c7f70
f98bbd6
2299978
650ef8e
30cc5b9
b82289d
9b5e343
2effde5
2516f6a
f70953d
5011635
ff50f47
35f0990
47f5b24
08dd7f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,17 @@ | ||
#include <fusion.h> | ||
#include <inlining.h> | ||
#include <ir_cloner.h> | ||
#include <lower_trivial_broadcast.h> | ||
#include <ops/arith.h> | ||
#include <scheduler/debug_utils.h> | ||
#include <scheduler/normalization_utils.h> | ||
#include <scheduler/utils.h> | ||
#include <utils.h> | ||
|
||
#include <ATen/cuda/CUDAContext.h> | ||
|
||
#include <memory> | ||
|
||
namespace nvfuser { | ||
namespace normalization_scheduler_utils { | ||
|
||
|
@@ -494,5 +502,196 @@ std::optional<GridOuterNormalizationParams> getGridOuterNormalizationParams( | |
return std::nullopt; | ||
} | ||
|
||
namespace { | ||
|
||
// Go through the resolution points one by one. Resolution points are points | ||
// in which the reduction branch meets the residual branch. These are points | ||
// where the persitent buffer may no longer be needed (one point could be | ||
// after another, and the buffer would be needed until the last resolution | ||
// points) | ||
std::vector<TensorView*> getPersistentUseOfBuffer( | ||
TensorView* buffer, | ||
const std::vector<TensorView*>& resolution_points) { | ||
std::vector<TensorView*> persistent_use_of_buffer; | ||
|
||
for (auto resolution_point : resolution_points) { | ||
// Need to go through all paths from the persistent buffer to the | ||
// resolution point | ||
auto chains_to_resolution = | ||
DependencyCheck::getAllDependencyChains(buffer, resolution_point); | ||
for (auto chain : chains_to_resolution) { | ||
auto tv_chain = ir_utils::filterByType<TensorView>(chain); | ||
|
||
// To move the persistent buffers to the inputs, we need to recompute | ||
// the persistent buffer for all branches that don't go through a | ||
// reduction. If there's a reduction on the current path between the | ||
// persistent buffer and resolution, continue, there's no need to | ||
// replicate this use. | ||
if (std::any_of(tv_chain.begin(), tv_chain.end(), [](TensorView* tv) { | ||
return tv->hasReduction(); | ||
})) { | ||
continue; | ||
} | ||
|
||
// Grab use of the buffer, chain[0] is the persistent buffer, chain[1] | ||
// is its first use. | ||
auto use = chain[1]; | ||
|
||
// Only grab unique uses, a persistent buffer could be used multiple | ||
// times in the same expression. | ||
if (std::find( | ||
persistent_use_of_buffer.begin(), | ||
persistent_use_of_buffer.end(), | ||
use) != persistent_use_of_buffer.end()) { | ||
continue; | ||
} | ||
persistent_use_of_buffer.emplace_back(use->as<TensorView>()); | ||
} | ||
} | ||
|
||
return persistent_use_of_buffer; | ||
} | ||
|
||
} // namespace | ||
|
||
// Move every projectable persistent buffer back toward the fusion inputs
// by recomputing it at each use that does not feed a reduction. Returns
// dummy outputs (buffer + replica sums) that later scheduling stages need
// for propagation (see the comment inside).
std::vector<TensorView*> projectPersistentBuffers(Fusion* fusion) {
  auto buffer_info = scheduler_utils::persistentBuffers(fusion);
  std::vector<TensorView*> dummy_outputs;

  // Convenience accessors.
  const auto& buffers = buffer_info.persistent_buffers;
  const auto& resolution_table =
      buffer_info.persistent_buffer_resolution_points;
  const auto& projectable = buffer_info.projectable_persistent_buffers;

  TORCH_INTERNAL_ASSERT(buffers.size() == resolution_table.size());

  // There is exactly one resolution-point entry per buffer, so iterate by
  // index to keep the buffer and its resolution points paired.
  for (auto idx : c10::irange(buffers.size())) {
    auto buffer = buffers[idx];
    const bool is_projectable =
        std::find(projectable.begin(), projectable.end(), buffer) !=
        projectable.end();
    if (!is_projectable) {
      continue;
    }

    const auto uses_to_replicate =
        getPersistentUseOfBuffer(buffer, resolution_table.at(idx));

    // For every use that does not head toward the reduction operations in
    // the persistent section of the graph, recompute the buffer and
    // redirect that use to the replica.
    for (auto use : uses_to_replicate) {
      TORCH_INTERNAL_ASSERT(use->definition() != nullptr);
      auto buffer_replicate = RecomputeTv::recompute(buffer);
      // Also emit `buffer_replicate + buffer` as a dummy output, creating
      // a shortcut buffer <--> buffer_replicate for propagation.
      // Why is this needed? Consider the fusion
      //
      //   T0[I]
      //   T1[b b I] = broadcast(T0)
      //   T2[b b r] = reduction(T1)
      //   T3[b b b] = broadcast(T2)
      //   T4[b, b, I] = T1 + T3
      //   T5[b, b, r] = reduction(T4)
      //
      // After projection it becomes
      //
      //   T0[I]
      //   T1[b b I] = broadcast(T0)
      //   T2[b b r] = reduction(T1)
      //   T3[b b b] = broadcast(T2)
      //   T6[b b I] = broadcast(T0)
      //   T4[b, b, I] = T6 + T3
      //   T5[b, b, r] = reduction(T4)
      //
      // During scheduling we need to propagate from T2 to T5, but neither
      // T2->T3->T4->T5 nor T2->T1->T0->T6->T4->T5 works — both have
      // missing root domains. Adding `T7 = T1 + T6` creates the path
      // `T2->T1->T7->T6->T4->T5`, which carries all root domain
      // information. See FusionBroadcastPersistentReduction_CUDA for an
      // example.
      dummy_outputs.emplace_back(add(buffer_replicate, buffer));
      ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
    }
  }
  return dummy_outputs;
}
|
||
std::unordered_set<TensorView*> fixUpInvalidPersistentBuffers(Fusion* fusion) { | ||
auto persistent_info = scheduler_utils::persistentBuffers(fusion); | ||
const auto& persistent_buffers = persistent_info.persistent_buffers; | ||
const auto& persistent_resolution_points = | ||
persistent_info.persistent_buffer_resolution_points; | ||
|
||
std::unique_ptr<ConcretizedBroadcastDomains> concretize_info; | ||
|
||
std::unordered_set<TensorView*> recomputed_tvs; | ||
|
||
bool recompute_done = false; | ||
|
||
for (auto buffer_i : c10::irange(persistent_buffers.size())) { | ||
auto buffer = persistent_buffers.at(buffer_i); | ||
|
||
// Check if this buffer needs to be recomputed | ||
bool need_recomputation = false; | ||
|
||
for (const auto i : c10::irange( | ||
buffer->getComputeAtPosition(), | ||
buffer->domain()->domain().size())) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Couldn't the axes be parallelized consistently even though outside the inline point? I guess it's a tendency that if we parallelize consistently we put that parallel dim within the compute at point, but this check seems to assume that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, you're right. What I thought was that if a broadcast domain appears left of a computeAt position, it's effectively expanded to the mapped consumer domain, so I though that would make the producer indexed consistently. However, if the mapped consumer domain itself may also be a broadcast domain, which effectively means we need something like the RAW sync logic... Will think about it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, or I need to finish my damn indexing PR, which would have this information 😭 |
||
auto axis = buffer->axis(i); | ||
|
||
// Unresolved data dependency could exist if: | ||
// - Parallelized by tidx and stored on Local | ||
// - Parallelized by bidx and stored on Local or Shared | ||
if ((axis->isThreadDim() && | ||
buffer->getMemoryType() == MemoryType::Local) || | ||
(axis->isBlockDim() && | ||
(buffer->getMemoryType() == MemoryType::Local || | ||
buffer->getMemoryType() == MemoryType::Shared))) { | ||
if (!concretize_info) { | ||
concretize_info = | ||
std::make_unique<ConcretizedBroadcastDomains>(fusion); | ||
} | ||
|
||
// concretize_info.isConcretized(axis) means the axis is a | ||
// concreteized broadcast domain | ||
if (!concretize_info->isConcretized(axis)) { | ||
continue; | ||
} | ||
|
||
need_recomputation = true; | ||
break; | ||
} | ||
} | ||
|
||
if (!need_recomputation) { | ||
continue; | ||
} | ||
|
||
const auto& resolution_points = persistent_resolution_points.at(buffer_i); | ||
|
||
const auto persistent_use_of_buffer = | ||
getPersistentUseOfBuffer(buffer, resolution_points); | ||
|
||
for (auto use : persistent_use_of_buffer) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So this is just adding a new expression for the first use of the persistent buffer? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This just follows the same logic as the buffer projection. It replicates all the dependent expressions from inputs to the first use of the persistent buffer. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's what I thought, I'm just forgetting how as it looks like it's only recomputing one use. |
||
TORCH_INTERNAL_ASSERT(use->definition() != nullptr); | ||
auto buffer_replicate = RecomputeTv::recompute(buffer); | ||
ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); | ||
recomputed_tvs.insert(buffer); | ||
recompute_done = true; | ||
} | ||
} | ||
|
||
if (recompute_done) { | ||
inlineMost(); | ||
} | ||
|
||
return recomputed_tvs; | ||
} | ||
|
||
} // namespace normalization_scheduler_utils | ||
} // namespace nvfuser |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -656,111 +656,5 @@ TensorView* sortAndRFactor(TensorView* reference_tv) { | |
return ir_utils::rfactorHelper(reference_tv, rfactor_axes); | ||
} | ||
|
||
// Take all projectable persistent buffers and move them to the inputs by
// recomputing each at every use that does not feed a reduction. Returns
// dummy outputs (buffer + replica sums) needed by later scheduling stages.
std::vector<TensorView*> projectPersistentBuffers(Fusion* fusion) {
  auto info = scheduler_utils::persistentBuffers(fusion);
  std::vector<TensorView*> dummy_outputs;

  // Convenience accessors.
  const auto& buffers = info.persistent_buffers;
  const auto& resolution_table = info.persistent_buffer_resolution_points;
  const auto& projectable = info.projectable_persistent_buffers;

  TORCH_INTERNAL_ASSERT(buffers.size() == resolution_table.size());

  // One resolution-point entry exists per buffer, so iterate by index to
  // keep the buffer and its resolution points paired.
  for (auto idx : c10::irange(buffers.size())) {
    auto buffer = buffers[idx];
    if (std::find(projectable.begin(), projectable.end(), buffer) ==
        projectable.end()) {
      continue;
    }

    auto resolution_points = resolution_table[idx];

    // Collect the buffer's first uses along every path that reaches a
    // resolution point without passing through a reduction. Resolution
    // points are the tensors where the reduction branch meets the residual
    // branch — the places past which the persistent buffer may no longer
    // be needed (with several points, it stays live until the last one).
    std::vector<Val*> uses_needing_replication;
    for (auto resolution_tv : resolution_points) {
      // Walk every dependency path from the buffer to this resolution
      // point.
      for (const auto& dep_chain :
           DependencyCheck::getAllDependencyChains(buffer, resolution_tv)) {
        auto tvs_on_chain = ir_utils::filterByType<TensorView>(dep_chain);

        // A reduction on this path consumes the buffer there, so this use
        // does not need to be replicated.
        if (std::any_of(
                tvs_on_chain.begin(),
                tvs_on_chain.end(),
                [](TensorView* tv) { return tv->hasReduction(); })) {
          continue;
        }

        // dep_chain[0] is the persistent buffer itself; dep_chain[1] is
        // its first use along this path.
        auto first_use = dep_chain[1];

        // A buffer may be consumed several times by one expression; record
        // each distinct use only once.
        if (std::find(
                uses_needing_replication.begin(),
                uses_needing_replication.end(),
                first_use) == uses_needing_replication.end()) {
          uses_needing_replication.emplace_back(first_use);
        }
      }
    }

    // For every collected use, recompute the buffer and redirect the use
    // to the replica.
    for (auto use : uses_needing_replication) {
      TORCH_INTERNAL_ASSERT(use->definition() != nullptr);
      auto buffer_replicate = RecomputeTv::recompute(buffer);
      // Also emit `buffer_replicate + buffer` as a dummy output, creating
      // a shortcut buffer <--> buffer_replicate for propagation.
      // Why is this needed? Consider the fusion
      //
      //   T0[I]
      //   T1[b b I] = broadcast(T0)
      //   T2[b b r] = reduction(T1)
      //   T3[b b b] = broadcast(T2)
      //   T4[b, b, I] = T1 + T3
      //   T5[b, b, r] = reduction(T4)
      //
      // After projection it becomes
      //
      //   T0[I]
      //   T1[b b I] = broadcast(T0)
      //   T2[b b r] = reduction(T1)
      //   T3[b b b] = broadcast(T2)
      //   T6[b b I] = broadcast(T0)
      //   T4[b, b, I] = T6 + T3
      //   T5[b, b, r] = reduction(T4)
      //
      // During scheduling we need to propagate from T2 to T5, but neither
      // T2->T3->T4->T5 nor T2->T1->T0->T6->T4->T5 works — both have
      // missing root domains. Adding `T7 = T1 + T6` creates the path
      // `T2->T1->T7->T6->T4->T5`, which carries all root domain
      // information. See FusionBroadcastPersistentReduction_CUDA for an
      // example.
      dummy_outputs.emplace_back(add(buffer_replicate, buffer));
      ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate);
    }
  }
  return dummy_outputs;
}
|
||
} // namespace reduction_scheduler_utils | ||
} // namespace nvfuser |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,11 +41,5 @@ TORCH_CUDA_CU_API void multiReductionInliner( | |
// Reduction inliner expects an rfactored domain. | ||
TORCH_CUDA_CU_API TensorView* sortAndRFactor(TensorView* reference_tv); | ||
|
||
// Take all projectable persistent buffers, and move them to the inputs. This | ||
// function create dummy outputs which should be used in later stages of the | ||
// scheduling. | ||
TORCH_CUDA_CU_API std::vector<TensorView*> projectPersistentBuffers( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved |
||
Fusion* fusion); | ||
|
||
} // namespace reduction_scheduler_utils | ||
} // namespace nvfuser |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Extracted from
projectPersistentBuffers
with no change