Make resolved broadcast tensors non-persistent #2563

Open · naoyam wants to merge 18 commits into devel

Conversation

@naoyam (Collaborator) commented on Mar 9, 2023

Fixes #2559

See #2559 for a C++ repro. The fusion math is:

Inputs:
  T0_g[ iS71{( ceilDiv(( ceilDiv(i0, 256) ), 8) )}, iS72{8}, iS70{256} ], __half
  T1_g[ iS38{( ceilDiv(( ceilDiv(( i2 * i3 ), 256) ), 8) )}, iS39{8}, iS37{256} ], __half
Outputs:
  T10_g[ iblockIdx.x58{( ceilDiv(( ceilDiv(( i2 * i3 ), 256) ), 8) )}, iS59{8}, ithreadIdx.x57{256} ] ca_pos( 3 ) produce_pos( 3 ), float

%kernel_math {
T3_l[ iblockIdx.x33{( ceilDiv(( ceilDiv(( i2 * i3 ), 256) ), 8) )}, iS34{8}, ithreadIdx.x32{256} ]
   = __half2float(T1_g[ iS38{( ceilDiv(( ceilDiv(( i2 * i3 ), 256) ), 8) )}, iS39{8}, iS37{256} ]);
T2_l[ iblockIdx.x67{( ceilDiv(( ceilDiv(i0, 256) ), 8) )}, iS68{8}, ithreadIdx.x66{256} ] ca_pos( 3 )
   = __half2float(T0_g[ iS71{( ceilDiv(( ceilDiv(i0, 256) ), 8) )}, iS72{8}, iS70{256} ]);
T4_l[ iblockIdx.x63{( ceilDiv(( ceilDiv(( i0 * 1 ), 256) ), 8) )}, iS64{8}, ithreadIdx.x62{256} ] produce_pos( 3 )
   = broadcast( T2_l[ iblockIdx.x67{( ceilDiv(( ceilDiv(i0, 256) ), 8) )}, iS68{8}, ithreadIdx.x66{256} ] ca_pos( 3 ) )
T5_l[ iblockIdx.x28{( ceilDiv(( ceilDiv(( i0 * i3 ), 256) ), 8) )}, iS29{8}, ithreadIdx.x27{256} ] ca_pos( 3 )
   = T4_l[ iblockIdx.x63{( ceilDiv(( ceilDiv(( i0 * 1 ), 256) ), 8) )}, iS64{8}, ithreadIdx.x62{256} ] produce_pos( 3 )
   + T3_l[ iblockIdx.x33{( ceilDiv(( ceilDiv(( i2 * i3 ), 256) ), 8) )}, iS34{8}, ithreadIdx.x32{256} ];
T11_l[ iblockIdx.x78{( ceilDiv(( ceilDiv(( i0 * i3 ), 256) ), 8) )}rf, rS79{8}rf, ithreadIdx.x77{256}rf ] ca_pos( 1 ) produce_pos( 3 )
   = reduction( T5_l[ iblockIdx.x28{( ceilDiv(( ceilDiv(( i0 * i3 ), 256) ), 8) )}, iS29{8}, ithreadIdx.x27{256} ] ca_pos( 3 ), op = add, initial value = double(0), allreduce = false )
T6_l[ rblockIdx.x80{( ceilDiv(( ceilDiv(( i0 * i3 ), 256) ), 8) )}, rthreadIdx.x81{256} ] produce_pos( 1 )
   = reduction( T11_l[ iblockIdx.x78{( ceilDiv(( ceilDiv(( i0 * i3 ), 256) ), 8) )}rf, rS79{8}rf, ithreadIdx.x77{256}rf ] ca_pos( 1 ) produce_pos( 3 ), op = add, initial value = double(0), allreduce = false )
T7_l[ bblockIdx.x48{( ceilDiv(( ceilDiv(( 1 * 1 ), 256) ), 8) )}, bS49{8}, bthreadIdx.x47{256} ]
   = broadcast( T6_l[ rblockIdx.x80{( ceilDiv(( ceilDiv(( i0 * i3 ), 256) ), 8) )}, rthreadIdx.x81{256} ] produce_pos( 1 ) )
T8_l[ iblockIdx.x43{( ceilDiv(( ceilDiv(( i2 * i3 ), 256) ), 8) )}, iS44{8}, ithreadIdx.x42{256} ] ca_pos( 3 )
   = T3_l[ iblockIdx.x33{( ceilDiv(( ceilDiv(( i2 * i3 ), 256) ), 8) )}, iS34{8}, ithreadIdx.x32{256} ]
   + T7_l[ bblockIdx.x48{( ceilDiv(( ceilDiv(( 1 * 1 ), 256) ), 8) )}, bS49{8}, bthreadIdx.x47{256} ];
T9_l[ iblockIdx.x53{( ceilDiv(( ceilDiv(( i0 * 1 ), 256) ), 8) )}, iS54{8}, ithreadIdx.x52{256} ] ca_pos( 3 )
   = T4_l[ iblockIdx.x63{( ceilDiv(( ceilDiv(( i0 * 1 ), 256) ), 8) )}, iS64{8}, ithreadIdx.x62{256} ] produce_pos( 3 )
   + T7_l[ bblockIdx.x48{( ceilDiv(( ceilDiv(( 1 * 1 ), 256) ), 8) )}, bS49{8}, bthreadIdx.x47{256} ];
T10_g[ iblockIdx.x58{( ceilDiv(( ceilDiv(( i2 * i3 ), 256) ), 8) )}, iS59{8}, ithreadIdx.x57{256} ] ca_pos( 3 ) produce_pos( 3 )
   = T8_l[ iblockIdx.x43{( ceilDiv(( ceilDiv(( i2 * i3 ), 256) ), 8) )}, iS44{8}, ithreadIdx.x42{256} ] ca_pos( 3 )
   + T9_l[ iblockIdx.x53{( ceilDiv(( ceilDiv(( i0 * 1 ), 256) ), 8) )}, iS54{8}, ithreadIdx.x52{256} ] ca_pos( 3 );
}
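
Roughly, the fusion definition behind this kernel math is a pattern like the sketch below (reconstructed from the math above rather than copied from #2559; the tensor ranks, names, and reduction axes are assumptions, and makeSymbolicTensor is the usual C++ test helper):

  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(1, DataType::Half); // [i0]
  auto tv1 = makeSymbolicTensor(2, DataType::Half); // [i2, i3]
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  auto tv2 = castOp(DataType::Float, tv0);  // T2
  auto tv3 = castOp(DataType::Float, tv1);  // T3
  auto tv4 = broadcast(tv2, {false, true}); // T4, broadcast resolved by tv3
  auto tv5 = add(tv4, tv3);                 // T5
  auto tv6 = sum(tv5, {0, 1});              // T6 (T11 is its rfactor)
  auto tv7 = broadcast(tv6, {true, true});  // T7
  auto tv8 = add(tv3, tv7);                 // T8
  auto tv9 = add(tv4, tv7);                 // T9
  auto tv10 = add(tv8, tv9);                // T10
  fusion.addOutput(tv10);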

Here, there are two persistent buffers:

T3_l[ iblockIdx.x33{( ceilDiv(( ceilDiv(( i2 * i3 ), 256) ), 8) )}, iS34{8}, ithreadIdx.x32{256} ]
T4_l[ iblockIdx.x63{( ceilDiv(( ceilDiv(( i0 * 1 ), 256) ), 8) )}, iS64{8}, ithreadIdx.x62{256} ] produce_pos( 3 )

They are the producers of T5:

T5_l[ iblockIdx.x28{( ceilDiv(( ceilDiv(( i0 * i3 ), 256) ), 8) )}, iS29{8}, ithreadIdx.x27{256} ] ca_pos( 3 )
   = T4_l[ iblockIdx.x63{( ceilDiv(( ceilDiv(( i0 * 1 ), 256) ), 8) )}, iS64{8}, ithreadIdx.x62{256} ] produce_pos( 3 )
   + T3_l[ iblockIdx.x33{( ceilDiv(( ceilDiv(( i2 * i3 ), 256) ), 8) )}, iS34{8}, ithreadIdx.x32{256} ];

Since T3 and T4 are not inlined at all and reside in Local memory, they must be exactly mapped with T5 and use the same parallelization. That is not the case with T4, so the parallel validation fails with:

C++ exception with description "producer->getMemoryType() == MemoryType::Global INTERNAL ASSERT FAILED at "/raid/nmaruyama/debug3/third_party/nvfuser/csrc/lower_sync_information.cpp":770, please report a bug to PyTorch. Inconsistent parallelization found between TV4 (T4_l[ iblockIdx.x63{( ceilDiv(( ceilDiv(( T0.size[0] * 1 ), 256) ), 8) )}, iS64{8}, ithreadIdx.x62{256} ] produce_pos( 3 )) and TV5(T5_l[ iblockIdx.x28{( ceilDiv(( ceilDiv(( T0.size[0] * T1.size[1] ), 256) ), 8) )}, iS29{8}, ithreadIdx.x27{256} ] ca_pos( 3 )). Producer is required to be in Global Memory based on parallelization strategy. RAW flags: (blockIdx.x threadIdx.x)

The workaround in this PR is to give up making T4 persistent and instead recompute it, much like the buffer projection (most of that logic is reused). Alternatively, if a buffer is parallelized only by TID, storing it in Shared memory would also work, but that is not considered here.
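
In terms of the kernel math above, the intended effect is roughly the following (illustrative only; the name of the recomputed tensor is hypothetical):

  T12 = broadcast( __half2float(T0) )  // recomputed copy of T4 for the second use
  T9  = T12 + T7                       // previously: T9 = T4 + T7

With that, T4 only needs to stay live until T5 is computed, so it no longer has to be persistent, while T3 remains a persistent buffer as before.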

None of the existing C++ tests and benchmarks are affected by this PR (except for a slight change to the buffer projection).

@naoyam naoyam marked this pull request as draft March 9, 2023 12:19
@naoyam naoyam marked this pull request as ready for review March 11, 2023 08:39
@naoyam changed the title from "[DRAFT] Make broadcast tensors non-persistent" to "Make broadcast tensors non-persistent" on Mar 11, 2023
@naoyam changed the title from "Make broadcast tensors non-persistent" to "Make resolved broadcast tensors non-persistent" on Mar 11, 2023
// Take all projectable persistent buffers, and move them to the inputs. This
// function creates dummy outputs which should be used in later stages of the
// scheduling.
TORCH_CUDA_CU_API std::vector<TensorView*> projectPersistentBuffers(
@naoyam (Collaborator, Author):

Moved

@@ -656,111 +656,5 @@ TensorView* sortAndRFactor(TensorView* reference_tv) {
return ir_utils::rfactorHelper(reference_tv, rfactor_axes);
}

std::vector<TensorView*> projectPersistentBuffers(Fusion* fusion) {
@naoyam (Collaborator, Author):

Moved

// where the persistent buffer may no longer be needed (one point could be
// after another, and the buffer would be needed until the last resolution
// points)
std::vector<TensorView*> getPersistentUseOfBuffer(
@naoyam (Collaborator, Author):

Extracted from projectPersistentBuffers with no change


} // namespace

std::vector<TensorView*> projectPersistentBuffers(Fusion* fusion) {
@naoyam (Collaborator, Author):

Moved from reduction_utils.

third_party/nvfuser/csrc/scheduler/normalization_utils.cpp (outdated, resolved)
@naoyam naoyam requested a review from csarofeen March 11, 2023 08:58
@csarofeen (Owner) left a comment:

Just a couple of questions; I'm not quite understanding the approach.


for (const auto i : c10::irange(
         buffer->getComputeAtPosition(),
         buffer->domain()->domain().size())) {
@csarofeen (Owner):

Couldn't the axes be parallelized consistently even though they're outside the inline point? I guess it's a tendency that if we parallelize consistently we put that parallel dim within the computeAt point, but this check seems to assume that?
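
A minimal sketch of the scenario in question (illustration, not code from this PR; tv4 and tv5 stand for T4 and T5 above): even though tv4 is not inlined into tv5 at all, its axes to the right of the computeAt position can still be bound to the same parallel types as tv5's,

  tv4->axis(0)->parallelize(ParallelType::BIDx);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);
  tv5->axis(0)->parallelize(ParallelType::BIDx);
  tv5->axis(-1)->parallelize(ParallelType::TIDx);

which is the consistent-but-not-inlined case the question is about.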

@naoyam (Collaborator, Author):

Ah, you're right. What I thought was that if a broadcast domain appears to the left of a computeAt position, it's effectively expanded to the mapped consumer domain, so I thought that would make the producer indexed consistently. However, the mapped consumer domain itself may also be a broadcast domain, which effectively means we need something like the RAW sync logic... Will think about it.

@csarofeen (Owner):

Yeah, or I need to finish my damn indexing PR, which would have this information 😭

const auto persistent_use_of_buffer =
    getPersistentUseOfBuffer(buffer, resolution_points);

for (auto use : persistent_use_of_buffer) {
@csarofeen (Owner):

So this is just adding a new expression for the first use of the persistent buffer?

@naoyam (Collaborator, Author):

This just follows the same logic as the buffer projection. It replicates all the dependent expressions from inputs to the first use of the persistent buffer.
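
As a sketch of that logic (not the PR's exact code; it assumes the same recompute/replace helpers the existing buffer projection relies on):

  for (auto use : persistent_use_of_buffer) {
    // Clone the chain of expressions from the fusion inputs up to `buffer`.
    auto buffer_copy = RecomputeTv::recompute(buffer);
    // Redirect this particular use to the clone, so the original buffer no
    // longer needs to stay live across the reduction.
    ir_utils::replaceValInExpr(use->definition(), buffer, buffer_copy);
  }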

@csarofeen (Owner):

That's what I thought; I'm just forgetting how, as it looks like it's only recomputing one use.
