Skip to content

Commit

Permalink
Indexing refactor stage 2 : Remove reference tensor in predicate inde…
Browse files Browse the repository at this point in the history
…xing logic (csarofeen#1784)

Co-authored-by: Christian Sarofeen <csarofeen@nvidia.com>
  • Loading branch information
shmsong and csarofeen authored Jul 2, 2022
1 parent f008140 commit 8d384da
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 65 deletions.
89 changes: 31 additions & 58 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2813,7 +2813,6 @@ auto getPredicateReferenceIndexing(
std::pair<Val*, Val*> getStartAndStopOffsets(
IterDomain* consumer_id,
TensorView* consumer_tv,
const ReferenceTensor& reference,
const std::unordered_map<IterDomain*, Val*>& consumer_start_index_map,
const std::unordered_map<IterDomain*, Val*>& consumer_stop_index_map,
bool padding_predicate,
Expand Down Expand Up @@ -2979,12 +2978,12 @@ std::pair<Val*, Val*> hoistPredicates(
Val* start_index,
Val* stop_index,
const std::vector<kir::ForLoop*>& loops,
std::vector<IterDomain*> loop_domains,
const std::unordered_map<IterDomain*, Val*> start_initial_loop_index_map,
const std::unordered_map<IterDomain*, Val*> stop_initial_loop_index_map,
kir::ForLoop* unswitch_or_vec_loop,
IterDomain* predicated_consumer_id,
TensorView* predicated_consumer_tv,
TensorDomain* ref_td,
const std::unordered_map<IterDomain*, Val*>& ref_start_index_map,
const std::unordered_map<IterDomain*, Val*>& ref_stop_index_map) {
TensorView* predicated_consumer_tv) {
const std::pair<Val*, Val*> same_indices{start_index, stop_index};

if (isDisabled(DisableOption::IndexHoist)) {
Expand All @@ -3004,8 +3003,8 @@ std::pair<Val*, Val*> hoistPredicates(
GpuLower::current()->commonIndexMap().insert(
predicated_consumer_id,
predicated_consumer_tv->domain(),
ref_td,
ref_stop_index_map,
loop_domains,
stop_initial_loop_index_map,
loops,
stop_index);
}
Expand All @@ -3021,8 +3020,8 @@ std::pair<Val*, Val*> hoistPredicates(
GpuLower::current()->commonIndexMap().insert(
predicated_consumer_id,
predicated_consumer_tv->domain(),
ref_td,
ref_start_index_map,
loop_domains,
start_initial_loop_index_map,
loops,
start_index);
}
Expand All @@ -3033,12 +3032,11 @@ std::pair<Val*, Val*> hoistPredicates(
} // namespace

// Returns predicates and the concrete (by loop map) root domains they cover
std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
getReferenceRootPredicates(
TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops,
kir::ForLoop* unswitch_or_vec_loop,
bool shift_padding) {
std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops,
kir::ForLoop* unswitch_or_vec_loop,
bool shift_padding) {
FUSER_PERF_SCOPE("GpuLower::Lower::Index::getReferenceRootPredicates");

const auto gpu_lower = GpuLower::current();
Expand All @@ -3047,22 +3045,9 @@ std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::

// Nothing needs to be done when padding is not required.
if (shift_padding && !needsPadding(consumer_tv)) {
return {{RootPredicateInfo::getFalseInfo()}, ReferenceTensor{}};
return {RootPredicateInfo::getFalseInfo()};
}

// Get a reference tensor replayed as existing loop structure
ReferenceTensor reference =
IndexReferenceReplay::getReference(loops, consumer_tv);

// Generate halo information for reference.
updateHaloInfoForReference(reference, consumer_tv);

const auto ref_2_consumer = indexMapReferenceTo(
consumer_tv, gpu_lower->caMap(), reference.concrete_to_id);

const auto reference_halo_extent_map =
getReferenceHaloExtentMap(reference, ref_2_consumer);

auto db_axis = gpu_lower->doubleBufferInfo().getDoubleBufferAxis(consumer_tv);

// Indexing is done without considering contig merging. Actual
Expand All @@ -3073,38 +3058,27 @@ std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), false),
{});

// Generate start and stop indexing from idgraph.
//
// Both start and stop positions may need to be predicated. Indexing
// differs when generating predicates for unswitch.
// NOTE: If we could find-and-replace KIR nodes, we could just
// generate one index map, clone it and replace the loop-to-index
// mappings of unswitched loops for the start predicate.
auto ref_stop_indexing = getPredicateReferenceIndexing(
loops, reference, unswitch_or_vec_loop, db_axis, false);
const auto consumer_stop_indexing = ref_stop_indexing.updateIndexCompute(
consumer_tv->domain(),
ref_2_consumer,
contig_finder,
reference_halo_extent_map);

auto stop_indexing_from_idgraph = getPredicateIndexingFromIdGraph(
loops, consumer_tv, unswitch_or_vec_loop, db_axis, false);
const auto consumer_stop_indexing = stop_indexing_from_idgraph.index;
const auto& consumer_stop_index_map = consumer_stop_indexing.indexMap();

// If not unswitch, share the same indexing map as the stop index
// map
const auto& ref_start_indexing = is_unswitch
? getPredicateReferenceIndexing(
loops, reference, unswitch_or_vec_loop, db_axis, true)
: ref_stop_indexing;

std::unordered_map<IterDomain*, Val*> consumer_start_index_map;
if (is_unswitch) {
const auto consumer_start_indexing = ref_start_indexing.updateIndexCompute(
consumer_tv->domain(),
ref_2_consumer,
contig_finder,
reference_halo_extent_map);
consumer_start_index_map = consumer_start_indexing.indexMap();
} else {
consumer_start_index_map = consumer_stop_index_map;
}
const auto start_indexing_from_idgraph = is_unswitch
? getPredicateIndexingFromIdGraph(
loops, consumer_tv, unswitch_or_vec_loop, db_axis, true)
: stop_indexing_from_idgraph;
const auto consumer_start_indexing = start_indexing_from_idgraph.index;
const auto& consumer_start_index_map = consumer_start_indexing.indexMap();

// Get the contiguous ids we need to generate predicates for
auto contig_id_infos =
Expand Down Expand Up @@ -3161,7 +3135,6 @@ std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
std::tie(info.start_offset_, info.stop_offset_) = getStartAndStopOffsets(
contig_id,
consumer_tv,
reference,
consumer_start_index_map,
consumer_stop_index_map,
shift_padding,
Expand All @@ -3175,12 +3148,12 @@ std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
start_index,
stop_index,
loops,
stop_indexing_from_idgraph.resolved_loop_domains,
start_indexing_from_idgraph.initial_concrete_index_map,
stop_indexing_from_idgraph.initial_concrete_index_map,
unswitch_or_vec_loop,
contig_id,
consumer_tv,
reference.domain,
ref_start_indexing.indexMap(),
ref_stop_indexing.indexMap());
consumer_tv);

// Build predicates for start positions as:
// start_index + start_offset >= 0
Expand Down Expand Up @@ -3217,7 +3190,7 @@ std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
pred_info_vec.emplace_back(info);
}

return {pred_info_vec, reference};
return pred_info_vec;
}

bool Index::protectWithMagicZero(
Expand Down
3 changes: 1 addition & 2 deletions torch/csrc/jit/codegen/cuda/index_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,7 @@ class Index {
//! this is not a bool value as if we have an unswitch loop with a vectorized
//! loop inside, we only want to base the "unswitch" like predicate on the
//! vectorized loop.
static std::pair<std::vector<RootPredicateInfo>, ReferenceTensor>
getReferenceRootPredicates(
static std::vector<RootPredicateInfo> getReferenceRootPredicates(
TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops,
kir::ForLoop* unswitch_or_vec_loop,
Expand Down
Loading

0 comments on commit 8d384da

Please sign in to comment.