Skip to content

Commit

Permalink
Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989)
Browse files Browse the repository at this point in the history
  • Loading branch information
csarofeen authored Sep 27, 2022
1 parent 15f2f6d commit fcf8c09
Show file tree
Hide file tree
Showing 13 changed files with 313 additions and 265 deletions.
20 changes: 14 additions & 6 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ Val* getConcreteProducerOffsetWithGather(
Val* window_idx = nullptr;

if (use_concrete_map) {
window_idx = index_map.at(ir_utils::caMapExactConcreteId(window_id));
window_idx = index_map.at(GpuLower::current()->caMap()->getConcreteMappedID(
window_id, IdMappingMode::EXACT));
} else {
window_idx = index_map.at(window_id);
}
Expand Down Expand Up @@ -703,7 +704,9 @@ void IndexCompute::collectIndexIntoPermissiveMap(
auto id_outputs = ir_utils::filterByType<IterDomain>(expr->outputs());
if (std::all_of(
id_outputs.begin(), id_outputs.end(), [this](IterDomain* id) {
return index_map_.count(ir_utils::caMapExactConcreteId(id));
return index_map_.count(
GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::EXACT));
})) {
// Visit this expression:
// LoopIndexingAnalysis::traverseFromDomainVals made sure that each
Expand All @@ -715,7 +718,9 @@ void IndexCompute::collectIndexIntoPermissiveMap(
for (auto id : id_inputs) {
// Collect backward pass results from this expression if they are
// made available in by this expression.
auto idx_it = index_map_.find(ir_utils::caMapExactConcreteId(id));
auto idx_it =
index_map_.find(GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::EXACT));

if (idx_it != index_map_.end()) {
permissive_index_map_
Expand All @@ -730,7 +735,8 @@ void IndexCompute::collectIndexIntoPermissiveMap(
void IndexCompute::updateIndexMapFromPermissiveMap(const Expr* id_expr) {
auto id_outputs = ir_utils::filterByType<IterDomain>(id_expr->outputs());
for (auto id : id_outputs) {
auto concrete_id = ir_utils::caMapExactConcreteId(id);
auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::EXACT);
// Only try to copy index val from permissive map when
// the index is missing.
if (!index_map_.count(concrete_id)) {
Expand Down Expand Up @@ -1506,7 +1512,8 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
// effort which means some domains may be producer's original domains.
std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
for (auto entry : c2p_map) {
auto ref_id = ir_utils::caMapExactConcreteId(entry.first);
auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID(
entry.first, IdMappingMode::EXACT);
auto p_id = entry.second;
if (ref_id->getParallelType() == ParallelType::Vectorize) {
p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
Expand Down Expand Up @@ -1745,7 +1752,8 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
// effort which means some domains may be the originals.
std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
for (auto entry : c2p_index_map) {
auto ref_id = ir_utils::caMapExactConcreteId(entry.first);
auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID(
entry.first, IdMappingMode::EXACT);
auto p_id = entry.second;
if (ref_id->getParallelType() == ParallelType::Vectorize) {
p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
Expand Down
10 changes: 8 additions & 2 deletions torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ namespace fuser {
namespace cuda {

namespace {
// Alias used for std::transform
IterDomain* exactConcreteId(IterDomain* id) {
return GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::EXACT);
}

//! Checks that the current loop nest is not realizing a serial
//! broadcast so that each index of producer buffer will only
Expand Down Expand Up @@ -83,7 +88,7 @@ bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) {
std::inserter(
producer_exact_concrete_root_ids,
producer_exact_concrete_root_ids.begin()),
ir_utils::caMapExactConcreteId);
exactConcreteId);

// Check if serial loop roots indexes any exact root id's that
// is not within the set of producer's root exact id's. These
Expand All @@ -92,7 +97,8 @@ bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) {
for (auto serial_loop_root :
ir_utils::filterByType<IterDomain>(serial_loop_roots)) {
if (!producer_exact_concrete_root_ids.count(
ir_utils::caMapExactConcreteId(serial_loop_root))) {
GpuLower::current()->caMap()->getConcreteMappedID(
serial_loop_root, IdMappingMode::EXACT))) {
return true;
}
}
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower_allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class AllocationInserter : public kir::ExprMutator {
// info.init_place_before, info.alloc_for_loop, info.alloc_place_before
void fillAllocationInformation(AllocationInformation& info, Expr* expr) {
auto loop_alloc_info =
loop_utils::getAllocInformation(info.buffer, for_loops_);
lower_loop_utils::getAllocInformation(info.buffer, for_loops_);

info.init_for_loop = loop_alloc_info.init_for_loop;
info.alloc_for_loop = loop_alloc_info.alloc_for_loop;
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ std::vector<IterDomain*> getLocalDomainOrdering(
std::sort(
merged_domain.begin(),
merged_domain.end(),
IterDomainDependencySorter(
ir_utils::IterDomainDependencySorter(
concrete_id_dependencies, GpuLower::current()->caMap()));
return merged_domain;
}
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ kir::Allocate* IndexLowering::allocateUniqueBuffer(

// No existing allocation found. Create a new one
auto new_buffer =
ir_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init);
lower_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init);

// Keep track of the allocation
alloc_map.emplace(out_tv, new_buffer);
Expand Down
78 changes: 54 additions & 24 deletions torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ IndexingParameters getLinearIndexParameters(

for (auto loop_idx : c10::irange(loops.size())) {
auto loop = loops[loop_idx];
auto index_domain = ir_utils::caMapExactConcreteId(loop_domain[loop_idx]);
auto index_domain = GpuLower::current()->caMap()->getConcreteMappedID(
loop_domain[loop_idx], IdMappingMode::EXACT);
if (loop->isTrivial()) {
// This is useful information in the case of
// MisalignedVectorize and double buffer epilog, etc.
Expand Down Expand Up @@ -149,7 +150,9 @@ IndexingParameters getLinearIndexParameters(

auto loop_id = loop_indexing.loopDomains()[loop_idx];

auto concrete_loop_id = ir_utils::caMapExactConcreteId(loop_id);
auto concrete_loop_id =
GpuLower::current()->caMap()->getConcreteMappedID(
loop_id, IdMappingMode::EXACT);

auto stage_depth =
GpuLower::current()->doubleBufferInfo().getStageDepthFor(
Expand Down Expand Up @@ -186,7 +189,7 @@ IndexingParameters getNonGlobalInitialIndexParameters(
}

auto alloc_tv = index_producer ? producer_tv : consumer_tv;
auto alloc_info = loop_utils::getAllocInformation(
auto alloc_info = lower_utils::getAllocInformation(
alloc_tv, loops, alloc_id_map, index_producer);

std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
Expand Down Expand Up @@ -217,7 +220,9 @@ IndexingParameters getNonGlobalInitialIndexParameters(
auto loop = loops[loop_idx];
auto loop_domain = loop_domains[loop_idx];

auto concrete_loop_domain = ir_utils::caMapExactConcreteId(loop_domain);
auto concrete_loop_domain =
GpuLower::current()->caMap()->getConcreteMappedID(
loop_domain, IdMappingMode::EXACT);

index_parameters.initial_concrete_id_index[concrete_loop_domain] =
loop_to_ind_map.at(loop);
Expand Down Expand Up @@ -399,7 +404,8 @@ IndexingParameters getPredicateInitialIndexParameters(
for (int loop_idx : c10::irange(loops.size())) {
auto loop = loops.at(loop_idx);
auto concrete_loop_domain =
ir_utils::caMapExactConcreteId(loop_domains.at(loop_idx));
GpuLower::current()->caMap()->getConcreteMappedID(
loop_domains.at(loop_idx), IdMappingMode::EXACT);
index_parameters.initial_concrete_id_index[concrete_loop_domain] =
loop_to_ind_map.at(loop);
}
Expand Down Expand Up @@ -566,7 +572,10 @@ LoopIndexingAnalysis::LoopIndexingAnalysis(
// consume each concrete id once so this map is well defined.
for (auto expr : replayed_exprs_) {
for (auto input_id : ir_utils::filterByType<IterDomain>(expr->inputs())) {
concrete_id_to_consumer_[ir_utils::caMapExactConcreteId(input_id)] = expr;
auto concrete_input_id =
GpuLower::current()->caMap()->getConcreteMappedID(
input_id, IdMappingMode::EXACT);
concrete_id_to_consumer_[concrete_input_id] = expr;
}
}

Expand Down Expand Up @@ -598,7 +607,8 @@ void LoopIndexingAnalysis::validateLoopStructure(
for (auto it_i = loops.begin(); it_i != loops.end(); ++it_i) {
// Largely duplicating original logic
auto loop_id = (*it_i)->iter_domain();
auto concrete_loop_id = ir_utils::caMapExactConcreteId(loop_id);
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
loop_id, IdMappingMode::EXACT);

TORCH_INTERNAL_ASSERT(
!concrete_to_loop.count(concrete_loop_id),
Expand Down Expand Up @@ -662,13 +672,22 @@ void LoopIndexingAnalysis::traverseFromDomainVals() {
}

IterDomain* LoopIndexingAnalysis::concretizeAndVisitId(IterDomain* id) {
auto concrete_id = ir_utils::caMapExactConcreteId(id);
auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::EXACT);
if (replayed_concrete_ids_.pushBack(concrete_id)) {
concrete_to_original_id_[concrete_id] = id;
}
return concrete_id;
}

namespace {
// Alias used for std::transform
IterDomain* exactConcreteId(IterDomain* id) {
return GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::EXACT);
}
} // namespace

void LoopIndexingAnalysis::visitExpr(Expr* expr) {
if (auto swizzle2d = dynamic_cast<Swizzle2D*>(expr)) {
// Swizzle outputs are already forwarded through
Expand Down Expand Up @@ -703,14 +722,14 @@ void LoopIndexingAnalysis::visitExpr(Expr* expr) {
consumed_ids.begin(),
consumed_ids.end(),
std::inserter(consumed_concrete_, consumed_concrete_.end()),
ir_utils::caMapExactConcreteId);
exactConcreteId);

auto produced_ids = ir_utils::filterByType<IterDomain>(expr->outputs());
std::transform(
produced_ids.begin(),
produced_ids.end(),
std::inserter(produced_concrete_, produced_concrete_.end()),
ir_utils::caMapExactConcreteId);
exactConcreteId);
}

bool LoopIndexingAnalysis::visitIdsAndCheckDuplication(
Expand Down Expand Up @@ -800,7 +819,8 @@ void LoopIndexingAnalysis::constructLoopDomains() {
// will complain for not having all outputs of the traversal.
for (auto id : ir_utils::filterByType<IterDomain>(all_ids_from_root)) {
if (id->uses().empty()) {
loop_domains_.pushBack(ir_utils::caMapExactConcreteId(id));
loop_domains_.pushBack(GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::EXACT));
}
}
}
Expand Down Expand Up @@ -880,7 +900,8 @@ IndexFromIdGraph getTensorIndexFromIdGraph(

// Exact id will have to be pulled from consumer side as the
// producer side are replayed ids.
auto exact_concrete_id = ir_utils::caMapExactConcreteId(consumer_id);
auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
consumer_id, IdMappingMode::EXACT);

index_update_map[exact_concrete_id] = target_id;

Expand Down Expand Up @@ -961,7 +982,8 @@ IndexFromIdGraph getPredicateIndexingFromIdGraph(
ir_utils::filterByType<IterDomain>(all_consumer_vals)) {
// Track the non-concrete id we were trying to bind index
// to, whether from producer or consumer.
auto exact_concrete_id = ir_utils::caMapExactConcreteId(consumer_id);
auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
consumer_id, IdMappingMode::EXACT);
index_update_map[exact_concrete_id] = consumer_id;
}

Expand Down Expand Up @@ -1040,7 +1062,8 @@ LoopIndexingTraversal::LoopIndexingTraversal(
auto next_ids =
ir_utils::filterByType<IterDomain>(nextValsInTraversalOrder(expr));
for (auto id : next_ids) {
auto concrete_id = ir_utils::caMapExactConcreteId(id);
auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::EXACT);
TORCH_INTERNAL_ASSERT(
concrete_id_to_dependency_.insert(std::make_pair(concrete_id, expr))
.second,
Expand Down Expand Up @@ -1108,7 +1131,8 @@ std::vector<Expr*> LoopIndexingTraversal::getExprList() {
for (auto prev_id :
ir_utils::filterByType<IterDomain>(prevValsInTraversalOrder(top))) {
auto prev_expr_it = concrete_id_to_dependency_.find(
ir_utils::caMapExactConcreteId(prev_id));
GpuLower::current()->caMap()->getConcreteMappedID(
prev_id, IdMappingMode::EXACT));
if (prev_expr_it != concrete_id_to_dependency_.end()) {
auto prev_expr = prev_expr_it->second;
if (!visited.count(prev_expr)) {
Expand Down Expand Up @@ -1145,7 +1169,7 @@ void LoopIndexingAnalysis::collectOutOfLineExprs() {
consumer_tv_->getComputeAtPosition(),
consumer_tv_->domain()->domain().end(),
std::inserter(out_of_line_ids, out_of_line_ids.end()),
ir_utils::caMapExactConcreteId);
exactConcreteId);

// Get the original selected list of index expressions
// in reverse topological order.
Expand All @@ -1160,7 +1184,9 @@ void LoopIndexingAnalysis::collectOutOfLineExprs() {
id_outputs.begin(),
id_outputs.end(),
[&out_of_line_ids](IterDomain* id) {
return out_of_line_ids.count(ir_utils::caMapExactConcreteId(id));
return out_of_line_ids.count(
GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::EXACT));
})) {
// Record out of line expression
out_of_line_exprs_.push_back(expr);
Expand All @@ -1171,7 +1197,7 @@ void LoopIndexingAnalysis::collectOutOfLineExprs() {
id_inputs.begin(),
id_inputs.end(),
std::inserter(out_of_line_ids, out_of_line_ids.end()),
ir_utils::caMapExactConcreteId);
exactConcreteId);
}
}
}
Expand All @@ -1192,14 +1218,14 @@ std::unordered_set<IterDomain*> LoopIndexing::getAllExactConcreteIdSet() const {
out_ids.begin(),
out_ids.end(),
std::inserter(all_id_set, all_id_set.end()),
ir_utils::caMapExactConcreteId);
exactConcreteId);

auto in_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
std::transform(
in_ids.begin(),
in_ids.end(),
std::inserter(all_id_set, all_id_set.end()),
ir_utils::caMapExactConcreteId);
exactConcreteId);
}
return all_id_set;
}
Expand Down Expand Up @@ -1244,7 +1270,9 @@ class LoopIndexingPreferredPathCompute : public IterVisitor {
}
mapped_id = c_id_it->second;
}
auto concrete_original_id = ir_utils::caMapExactConcreteId(mapped_id);
auto concrete_original_id =
GpuLower::current()->caMap()->getConcreteMappedID(
mapped_id, IdMappingMode::EXACT);
if (all_concrete_ids.count(concrete_original_id)) {
if (original_id->isBroadcast() || original_id->isReduction() ||
original_id->isStride()) {
Expand All @@ -1270,16 +1298,18 @@ class LoopIndexingPreferredPathCompute : public IterVisitor {
all_iter_inputs.begin(),
all_iter_inputs.end(),
[&](IterDomain* inp_id) {
return this->preferred_path_.find(ir_utils::caMapExactConcreteId(
inp_id)) != this->preferred_path_.end();
return this->preferred_path_.find(
GpuLower::current()->caMap()->getConcreteMappedID(
inp_id, IdMappingMode::EXACT)) !=
this->preferred_path_.end();
})) {
auto all_iter_outputs = ir_utils::filterByType<IterDomain>(e->outputs());

std::transform(
all_iter_outputs.begin(),
all_iter_outputs.end(),
std::inserter(preferred_path_, preferred_path_.end()),
ir_utils::caMapExactConcreteId);
exactConcreteId);
}
}

Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ class WarSyncInserter : private kir::ExprMutator {
auto maybe_aliased_tv = alloc_map_.getRealBuffer(tv);
auto alloc_it = smem_allocations_.find(maybe_aliased_tv);
auto ca_loop =
loop_utils::getAllocInformation(tv, for_loops_).init_for_loop;
lower_utils::getAllocInformation(tv, for_loops_).init_for_loop;
if (alloc_it == smem_allocations_.end()) {
WarMemoryInfo mem_info;
mem_info.ca_loop = ca_loop;
Expand Down Expand Up @@ -486,7 +486,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
Expr* sync_expr = nullptr;
kir::Allocate* maybe_alloc = nullptr;
if (sync_bitmap.hasBID()) {
maybe_alloc = ir_utils::allocGlobalBufferForGridComm(
maybe_alloc = lower_utils::allocGlobalBufferForGridComm(
getGridSyncBufferSize(sync_bitmap), DataType::Int, true);
sync_expr = IrBuilder::create<kir::GridSync>(
sync_bitmap, maybe_alloc->buffer());
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower_loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ void LoopNestGenerator::generate(const std::vector<Expr*>& exprs) {
std::sort(
loop_structure.rbegin(),
loop_structure.rend(),
IterDomainDependencySorter(
ir_utils::IterDomainDependencySorter(
concrete_id_dependencies, GpuLower::current()->caMap()));
loop_structures_[tv] = loop_structure;
}
Expand Down
Loading

0 comments on commit fcf8c09

Please sign in to comment.