save-tmp-work
zasdfgbnm committed Oct 6, 2022
1 parent 80198b2 commit 31e8d0a
Showing 7 changed files with 98 additions and 58 deletions.
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp
@@ -908,8 +908,8 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
     vectorize_factor = 1;
   }
 
-  // Try expanding vectorization to contig merged domains
-  vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains(
+  // Try adjusting vectorization to contig merged domains
+  vectorize_factor = vectorize_helper::adjustVectorizationToContigMergedDomains(
       fusion,
       runtime_info,
       vectorizable_inputs_outputs,
17 changes: 17 additions & 0 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
@@ -340,6 +340,23 @@ std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
     vectorize_factor = std::min(vectorize_factor, tv_vectorize_factor);
   }
 
+  // Try adjusting vectorization to contig merged domains
+  // TODO: This is an expensive function that shouldn't be in heuristics without
+  // caching.
+  auto adjusted_vector_word_size =
+      vectorize_helper::adjustVectorizationToContigMergedDomains(
+          fusion,
+          runtime_info,
+          vectorizable_inputs_outputs,
+          largest_out,
+          break_point,
+          vectorize_factor);
+
+  vectorize_factor = std::min(
+      static_cast<size_t>(max_unroll_factor), adjusted_vector_word_size);
+
+  std::cout << "[in scheduler] vectorize_factor: " << vectorize_factor << std::endl;
+
   if (vectorize_factor == 1) {
     params->vectorize = false;
     params->unroll_factor = max_unroll_factor;
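As a worked illustration of the clamp this hunk adds (a minimal sketch with hypothetical names, not the scheduler's actual surrounding code): the adjusted word size may exceed the unroll budget, so the smaller of the two wins, and a result of 1 feeds the `vectorize = false` branch shown at the end of the hunk.

```cpp
#include <algorithm>
#include <cstddef>

// Returns the final vectorization factor; a result of 1 means the scheduler
// disables vectorization and uses max_unroll_factor for unrolling instead.
size_t finalVectorizeFactor(size_t adjusted_word_size, size_t max_unroll_factor) {
  return std::min(max_unroll_factor, adjusted_word_size);
}

// finalVectorizeFactor(8, 4) == 4: the unroll budget caps the vector width.
// finalVectorizeFactor(1, 4) == 1: vectorization is disabled entirely.
```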
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp
@@ -953,8 +953,8 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
     vectorize_factor = 1;
   }
 
-  // Try expanding vectorization to contig merged domains
-  vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains(
+  // Try adjusting vectorization to contig merged domains
+  vectorize_factor = vectorize_helper::adjustVectorizationToContigMergedDomains(
       fusion,
       runtime_info,
       vectorizable_inputs_outputs,
80 changes: 42 additions & 38 deletions torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
@@ -580,14 +580,22 @@ size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(TensorView* tv) {
 
   auto tv_root_no_reductions = TensorDomain::noReductions(tv_root);
 
+  std::vector<int64_t> new_to_old_map;
+  auto new2old = [&new_to_old_map](int64_t new_pos) {
+    return new_to_old_map.empty() ? new_pos : new_to_old_map[new_pos];
+  };
+
   auto contiguity = tv->domain()->contiguity();
   // Appears after reductions the reduction domain often has a contiguity entry.
   // This only matters if the result of the reduction is an output
   if (contiguity.size() == tv_root.size() &&
       contiguity.size() != tv_root_no_reductions.size()) {
     std::vector<bool> new_contiguity;
+    new_to_old_map.reserve(tv_root_no_reductions.size());
     new_contiguity.reserve(tv_root_no_reductions.size());
     for (auto i : c10::irange(tv_root.size())) {
       if (!tv_root[i]->isReduction()) {
+        new_to_old_map.push_back(i);
         new_contiguity.push_back(contiguity[i]);
       }
     }
@@ -619,11 +627,12 @@ size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(TensorView* tv) {
   }
 
   auto numel = 1;
-  for (auto i : c10::irange(tv_root_size)) {
-    auto root_i = tv_root_size - i - 1;
+  int64_t root_i = tv_root_size - 1;
+  while (root_i >= 0) {
     auto root_id = tv_root[root_i];
 
     if (root_id->extent()->isOneInt() || root_id->isBroadcast()) {
+      root_i--;
       continue;
     }
 
@@ -640,15 +649,40 @@ size_t SchedulerRuntimeInfo::getMaxVectorizableWidth(TensorView* tv) {
 
     // Still contiguous
     numel *= dim_size->as<int64_t>();
+
+    root_i--;
   }
 
+  int64_t root_pos_discontig = root_i;
+
+  size_t stride_alignment;
+  if (root_pos_discontig < 0) {
+    // The entire tensor is contiguous. For such case, we just set
+    // stride_alignment to max_alignment_size_in_byte so all next_vector_size
+    // checks below against it passes.
+    stride_alignment = max_alignment_size_in_byte;
+  } else {
+    // root_pos_discontig is the right-most discontiguous dimension
+    stride_alignment = getStrideAlignmentSize(
+        tv, new2old(root_pos_discontig)); // TODO: use old pos
+    for (auto pos = root_pos_discontig - 1; pos >= 0; pos--) {
+      if (!contiguity[pos]) {
+        stride_alignment = std::min(
+            stride_alignment,
+            getStrideAlignmentSize(tv, new2old(pos))); // TODO: use old pos
+      }
+    }
+  }
+  stride_alignment /= item_size;
+
   // Assuming intermediate tensors have friendly alignment, and
   // all contiguity true. Determine the largest power of 2 below
   // innermost dimension size for the word size of vectorizaiton
   size_t vector_size = 1;
   size_t next_vector_size = 2;
   while (next_vector_size <= max_vector_size && next_vector_size <= numel &&
-         numel % next_vector_size == 0) {
+         numel % next_vector_size == 0 &&
+         stride_alignment % next_vector_size == 0) {
     vector_size = next_vector_size;
     next_vector_size *= 2;
   }
@@ -724,32 +758,9 @@ size_t SchedulerRuntimeInfo::getInnerDimVectorizableWidth(TensorView* tv) {
     return 1;
   }
 
-  int id_pos_discontig = id_pos;
-  while (id_pos_discontig >= 0 && contiguity[id_pos_discontig]) {
-    id_pos_discontig--;
-  }
-
   size_t item_size =
       dataTypeSize(tv->dtype(), indexModeToDtype(getIndexMode()));
 
-  size_t inner_dimension_stride;
-  if (id_pos_discontig < 0) {
-    // The entire tensor is contiguous. For such case, we just set
-    // inner_dimension_stride to zero so all next_vector_size checks below
-    // against it passes.
-    inner_dimension_stride = 0;
-  } else {
-    // id_pos_discontig is the right-most discontiguous dimension
-    inner_dimension_stride = getStrideAlignmentSize(tv, id_pos_discontig);
-    for (auto pos = id_pos_discontig - 1; pos >= 0; pos--) {
-      if (!contiguity[pos]) {
-        inner_dimension_stride =
-            std::min(inner_dimension_stride, getStrideAlignmentSize(tv, pos));
-      }
-    }
-  }
-  inner_dimension_stride /= item_size;
-
   // Alignment should always at least be the data type size
   TORCH_INTERNAL_ASSERT(getAlignmentSize(tv) % item_size == 0);
   size_t max_vector_size = getAlignmentSize(tv) / item_size;
@@ -759,21 +770,14 @@ size_t SchedulerRuntimeInfo::getInnerDimVectorizableWidth(TensorView* tv) {
   // innermost dimension size for the word size of vectorizaiton
   size_t vector_size = 1;
   size_t next_vector_size = 2;
-
-  // Get the inner dimension size. If the inner dimension is contiguous, then
-  // expand left until it is no longer contiguous
-  size_t inner_dimension_size = 1;
-  for (int pos = id_pos; pos > id_pos_discontig; pos--) {
-    auto maybe_inner_dimension_size =
-        expression_evaluator_->evaluate(tv->axis(pos)->extent());
-    TORCH_INTERNAL_ASSERT(maybe_inner_dimension_size.has_value());
-    inner_dimension_size *= maybe_inner_dimension_size->as<int64_t>();
-  }
+  auto maybe_inner_dimension_size =
+      expression_evaluator_->evaluate(inner_most_dim->extent());
+  TORCH_INTERNAL_ASSERT(maybe_inner_dimension_size.has_value());
+  size_t inner_dimension_size = maybe_inner_dimension_size->as<int64_t>();
 
   while (next_vector_size <= max_vector_size &&
          next_vector_size <= inner_dimension_size &&
-         inner_dimension_size % next_vector_size == 0 &&
-         inner_dimension_stride % next_vector_size == 0) {
+         inner_dimension_size % next_vector_size == 0) {
     vector_size = next_vector_size;
     next_vector_size *= 2;
   }
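The hunks above move the stride-alignment restriction out of `getInnerDimVectorizableWidth` and into `getMaxVectorizableWidth`. A minimal standalone sketch of the resulting word-size search (a hypothetical free function, not the nvFuser code itself): the vector width is the largest power of two that fits the alignment budget, divides the contiguous inner element count, and divides the alignment (in elements) of every discontiguous stride.

```cpp
#include <cstddef>

size_t maxVectorWidth(
    size_t contiguous_numel,   // product of the contiguous innermost extents
    size_t max_vector_size,    // alignment size of the tensor / element size
    size_t stride_alignment) { // min stride alignment (in elements) over
                               // discontiguous dims; the budget if none exist
  size_t vector_size = 1;
  size_t next_vector_size = 2;
  while (next_vector_size <= max_vector_size &&
         next_vector_size <= contiguous_numel &&
         contiguous_numel % next_vector_size == 0 &&
         stride_alignment % next_vector_size == 0) {
    vector_size = next_vector_size;
    next_vector_size *= 2;
  }
  return vector_size;
}

// With 24 contiguous elements, a budget of 4, and strides aligned to 4
// elements, the search yields 4; if the strides are only 2-element aligned,
// it stops at 2.
```

The `new_to_old_map` lambda added at the top of the function addresses a position mismatch: once reduction axes are filtered out of the root domain, a position in the filtered contiguity vector no longer matches the axis index the stride query expects. A self-contained sketch of that idea, under the assumption that axes are simply dropped in order:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // true = axis kept, false = reduction axis dropped from the root domain
  std::vector<bool> kept = {true, false, true, true};
  std::vector<int64_t> new_to_old_map;
  for (int64_t i = 0; i < static_cast<int64_t>(kept.size()); i++) {
    if (kept[i]) {
      new_to_old_map.push_back(i);
    }
  }
  auto new2old = [&](int64_t new_pos) {
    return new_to_old_map.empty() ? new_pos : new_to_old_map[new_pos];
  };
  std::cout << new2old(1) << std::endl; // filtered position 1 -> original axis 2
}
```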
28 changes: 19 additions & 9 deletions torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp
@@ -161,13 +161,7 @@ size_t collectMaxVectorizeSizeWithContigMerge(
   return vector_size;
 }
 
-//! Attempt to expand vectorized domains to contig merged domains. Break point
-//! identifies the point in which you can't propagate contiguous merges. For
-//! example in pointwise this is the point where we want to split the
-//! parallelization to take advantage of broadcast, and for reduction
-//! schedulers it's the point where we switch from a reduction domain to an
-//! iter domain (or vice versa).
-size_t expandVectorizationToContigMergedDomains(
+size_t adjustVectorizationToContigMergedDomains(
     Fusion* fusion,
     SchedulerRuntimeInfo& runtime_info,
     const std::vector<TensorView*> vectorizable_inputs_outputs,
@@ -177,6 +171,7 @@ size_t expandVectorizationToContigMergedDomains(
   size_t max_expand_size = SchedulerRuntimeInfo::max_alignment_size_in_byte;
   size_t common_alignment_size =
       SchedulerRuntimeInfo::max_alignment_size_in_byte;
+  std::cout << "default_word_size: " << default_word_size << std::endl;
 
   for (auto inp_out : vectorizable_inputs_outputs) {
     auto dtype_size = dataTypeSize(
@@ -187,15 +182,18 @@
         SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_size);
     max_expand_size = std::min(
         max_expand_size, runtime_info.getMaxVectorizableWidth(inp_out));
+    std::cout << "runtime_info.getMaxVectorizableWidth(inp_out): "
+              << runtime_info.getMaxVectorizableWidth(inp_out) << std::endl;
+    std::cout << "max_expand_size: " << max_expand_size << std::endl;
     common_alignment_size =
         std::min(common_alignment_size, runtime_info.getAlignmentSize(inp_out));
   }
 
   // If there's no possibility to increase vector size of provided tensors,
   // then don't bother doing a more complex analysis to try and do so, just
   // return early.
-  if (max_expand_size == default_word_size) {
-    return default_word_size;
+  if (max_expand_size <= default_word_size) {
+    return max_expand_size;
   }
 
   auto ca_map = ComputeAtMap(fusion);
@@ -205,6 +203,8 @@
   const int num_merged_domains =
       static_cast<int>(ref_root.size()) - static_cast<int>(break_point);
 
+  std::cout << "num_merged_domains: " << num_merged_domains << std::endl;
+
   // No expansion with no merged domain
   if (num_merged_domains == 0) {
     return default_word_size;
@@ -228,6 +228,8 @@
 
   cleanUpInnermostMergedDomains(ref_root, merged_domain);
 
+  std::cout << "word_size: " << word_size << std::endl;
+
   // Stop if the reference doesn't get a larger word size.
   if (word_size <= default_word_size) {
     return default_word_size;
@@ -255,14 +257,18 @@
         ++tv_num_merged_domains;
       }
     }
+    std::cout << "tv: " << tv->toString() << ", tv_num_merged_domains: " << tv_num_merged_domains
+              << std::endl;
 
     size_t tv_word_size = 1;
     if (tv_num_merged_domains > 1) {
       auto tv_merged_domain =
           mergeInnermostDomains(tv_root, tv_num_merged_domains);
       if (tv_merged_domain == nullptr) {
         tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv);
+        std::cout << "tv_merged_domain == nullptr" << std::endl;
       } else {
+        std::cout << "tv_merged_domain != nullptr" << std::endl;
         tv_word_size = collectMaxVectorizeSizeWithContigMerge(
             tv,
             tv_merged_domain,
@@ -274,10 +280,14 @@
     } else {
       tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv);
     }
+    std::cout << "tv_word_size: " << tv_word_size << std::endl;
 
     word_size = std::min(word_size, tv_word_size);
   }
 
+  if (word_size >= max_expand_size) {
+    return max_expand_size;
+  }
   return word_size;
 }
 
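Putting the three changed return paths together, a standalone sketch (a hypothetical helper, assuming the variable names used above) of the function's control flow after this commit: the word size can now be adjusted down when no input/output supports the caller's default, and the expanded result is clamped to the per-tensor maximum.

```cpp
#include <algorithm>
#include <cstddef>

size_t adjustWordSize(
    size_t default_word_size, // the caller's tentative vectorization factor
    size_t max_expand_size,   // min of getMaxVectorizableWidth over tensors
    size_t expanded_size) {   // result of the contig-merge analysis
  if (max_expand_size <= default_word_size) {
    return max_expand_size; // adjust down; previously only == returned early
  }
  if (expanded_size <= default_word_size) {
    return default_word_size; // merging domains didn't buy a larger word size
  }
  return std::min(expanded_size, max_expand_size); // the new final clamp
}
```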
18 changes: 11 additions & 7 deletions torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h
@@ -24,13 +24,17 @@ IterDomain* mergeInnermostDomains(
     const std::vector<IterDomain*>& domain,
     int num_merged_domains);
 
-//! Attempt to expand vectorized domains to contig merged domains. Break point
-//! identifies the point in which you can't propagate contiguous merges. For
-//! example in pointwise this is the point where we want to split the
-//! parallelization to take advantage of broadcast, and for reduction schedulers
-//! it's the point where we switch from a reduction domain to an iter domain (or
-//! vice versa).
-size_t expandVectorizationToContigMergedDomains(
+//! Adjust the vectorization sizes. It does the following things:
+//! 1. It checks the strides of all discontiguous dimensions to make sure that
+//!    the provided vectorization size is allowed. If not allowed, adjust it
+//!    down.
+//! 2. Attempt to expand vectorized domains to contig merged domains. Break
+//!    point identifies the point in which you can't propagate contiguous
+//!    merges. For example in pointwise this is the point where we want to split
+//!    the parallelization to take advantage of broadcast, and for reduction
+//!    schedulers it's the point where we switch from a reduction domain to an
+//!    iter domain (or vice versa).
+size_t adjustVectorizationToContigMergedDomains(
     Fusion* fusion,
     SchedulerRuntimeInfo& runtime_info,
     const std::vector<TensorView*> vectorizable_inputs_outputs,
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -26191,6 +26191,11 @@ TEST_F(NVFuserTest, FusionVectorizeStrideContiguity3D_CUDA) {
     at::Tensor t0 = at::randn({1000000, size, 3}, options).narrow(1, 0, 8);
     auto cg_outputs = fec.runFusionWithInputs({t0});
 
+    std::cout << "shape: " << t0.sizes() << std::endl;
+    std::cout << "stride: " << t0.strides() << std::endl;
+    std::cout << "expect: " << vec << std::endl;
+    std::cout << "getVecSizeForPointwise(fec): " << getVecSizeForPointwise(fec) << std::endl;
+
     TORCH_CHECK(getVecSizeForPointwise(fec) == vec);
 
     testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
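For context on what the new prints show (illustrative arithmetic only, not part of the test): `narrow(1, 0, 8)` keeps the parent tensor's strides, so the `{1000000, 8, 3}` view is discontiguous in its outer dimension, and the row stride decides which vector widths stay aligned.

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto options = at::TensorOptions().dtype(at::kFloat);
  int64_t size = 16;
  at::Tensor t0 = at::randn({1000000, size, 3}, options).narrow(1, 0, 8);
  std::cout << t0.sizes() << std::endl;   // [1000000, 8, 3]
  std::cout << t0.strides() << std::endl; // [48, 3, 1] when size = 16
  // Each row holds 8 * 3 = 24 contiguous floats and starts size * 3 floats
  // after the previous one. With size = 16 that is 192 bytes, a multiple of
  // 16, so a 4-wide float vector stays aligned; with size = 10 it is 120
  // bytes, so only a 2-wide vector does.
  return 0;
}
```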
