Redundant thread compute analysis to avoid unnecessary sync insertion (
shmsong authored Jul 19, 2022
1 parent 942be5b commit b7a4d93
Showing 6 changed files with 494 additions and 17 deletions.
15 changes: 14 additions & 1 deletion torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/codegen.h>
#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
@@ -516,7 +517,19 @@ std::vector<Val*> Fusion::usedMathVals() {
return used_math_vals;
}

std::unordered_set<Expr*> Fusion::unordered_uses(Val* val) const {
std::vector<Val*> Fusion::terminatingMathVals() {
VectorOfUniqueEntries<Val*> result;
auto used_vals = usedMathVals();
for (auto v : used_vals) {
// Locate the vals that are not used by any exprs but have valid definitions.
if (unordered_uses(v).empty() && v->definition() != nullptr) {
result.pushBack(v);
}
}
return result.vector();
}

std::unordered_set<Expr*> Fusion::unordered_uses(const Val* val) const {
return std::unordered_set<Expr*>(val->uses().begin(), val->uses().end());
}
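
For intuition, the new terminatingMathVals() collects the sinks of the used math graph: vals that are produced by some expression but feed no further expression. The standalone sketch below models the same filtering on a toy graph; Node, collectTerminating, and the fields used here are illustrative stand-ins, not nvfuser types.

// Toy model of terminatingMathVals(): a val is "terminating" if it has a
// definition (it is produced by an expr) but no uses (no expr consumes it).
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool has_definition = false;  // produced by some expression
  std::vector<Node*> uses;      // expressions (modeled as nodes) consuming it
};

std::vector<Node*> collectTerminating(const std::vector<Node*>& used_vals) {
  std::vector<Node*> result;
  for (Node* v : used_vals) {
    if (v->uses.empty() && v->has_definition) {
      result.push_back(v);
    }
  }
  return result;
}

int main() {
  Node t0{"T0"}, t1{"T1"}, t2{"T2"};
  t1.has_definition = true;  // T1 = op(T0)
  t2.has_definition = true;  // T2 = op(T1)
  t0.uses = {&t1};
  t1.uses = {&t2};
  for (Node* v : collectTerminating({&t0, &t1, &t2})) {
    std::cout << v->name << " is terminating\n";  // prints: T2 is terminating
  }
}

In the real implementation the candidates come from usedMathVals() and uniqueness is handled by VectorOfUniqueEntries.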

10 changes: 9 additions & 1 deletion torch/csrc/jit/codegen/cuda/fusion.h
@@ -153,8 +153,16 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer {
//! also included as they must show up in the final code.
std::vector<Val*> usedMathVals();

//! Returns all vals that are produced by used math expressions and
//! also do not have further consumers.
//!
//! In the case of an active multi-output expression, the returned vector
//! will include the expression outputs that did not lead to a fusion
//! output.
std::vector<Val*> terminatingMathVals();

//! Return all Exprs that use val
std::unordered_set<Expr*> unordered_uses(Val* val) const;
std::unordered_set<Expr*> unordered_uses(const Val* val) const;

//! Return the Expr that produces val
Expr* definition(const Val* val) const;
37 changes: 23 additions & 14 deletions torch/csrc/jit/codegen/cuda/lower_sync_information.cpp
@@ -137,6 +137,16 @@ void SyncMap::build(Fusion* fusion) {
->threadPredMap()
.getPredicateInfo(producer)
.redundant_types;
// Get the parallel types that are inactive in consumer's use chains.
auto producer_redundant_use_types = GpuLower::current()
->threadPredMap()
.getPredicateInfo(producer)
.redundant_use_types;

// In the sync info pass we only consider the parallel types in the
// producer that are redundantly produced but not redundantly consumed.
producer_redundant_types =
producer_redundant_types & (~producer_redundant_use_types);

for (const auto producer_i : c10::irange(producer->nDims())) {
auto producer_axis = producer->axis(producer_i);
@@ -205,25 +215,24 @@ void SyncMap::build(Fusion* fusion) {
continue;
}

auto parallel_type_i = getParallelTypeBitMapOffset(parallel_type);

auto p_id = producer_parallel_ids[parallel_type_i];
auto c_id = consumer_parallel_ids[parallel_type_i];

// If the consumer is parallelized with this type but the producer is
// predicated redundant on this type, this parallel dimension
// is a RAW dimension. See test: FusionSeriaSmemWriteParallelRead1/2
//
// Even if the consumer is not parallelized with this type, we would still
// need a RAW sync unless all use chains of the producer end with an
// output with the same redundant type.
// TODO: need a separate pass to detect the case where no RAW sync
// is needed here, i.e. all use-def chains are redundant.
// In the case when the parallel IDs are mapped by the CA map,
// we will additionally need to consider whether the producer is
// a redundant write. The RAW dim can be skipped only if
// the consumer use chains contain only redundant uses.
// TODO:
// still losing a bit of precision here for expr-ordering-sensitive
// cases, but we can wait until that becomes a perf limiter to fix.
if (producer_redundant_types.get(parallel_type)) {
raw_dims.set(parallel_type);
continue;
}

auto parallel_type_i = getParallelTypeBitMapOffset(parallel_type);

auto p_id = producer_parallel_ids[parallel_type_i];
auto c_id = consumer_parallel_ids[parallel_type_i];

if (p_id == nullptr && c_id == nullptr) {
continue;
} else if (p_id != nullptr && c_id != nullptr) {
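
The filtering added above can be read as: a parallel type remains a RAW-sync concern only when the producer writes it redundantly and its consumer chains do not consume it redundantly. Below is a minimal self-contained sketch of that bit logic, with std::bitset standing in for ParallelTypeBitmap and a bit layout chosen purely for illustration.

// Simplified model of the new producer_redundant_types filtering in
// SyncMap::build. std::bitset stands in for ParallelTypeBitmap.
#include <bitset>
#include <iostream>

int main() {
  // Bits (illustrative): 0=TIDx, 1=TIDy, 2=TIDz, 3=BIDx, 4=BIDy, 5=BIDz.
  std::bitset<6> producer_redundant_types;      // redundantly written types
  std::bitset<6> producer_redundant_use_types;  // redundantly consumed types

  producer_redundant_types.set(0);      // producer writes redundantly on TIDx
  producer_redundant_types.set(3);      // ... and on BIDx
  producer_redundant_use_types.set(3);  // every consumer chain is redundant on BIDx

  // Keep only the types that are redundantly produced but not redundantly
  // consumed; these are the ones that still need RAW sync consideration.
  auto raw_candidates = producer_redundant_types & ~producer_redundant_use_types;

  std::cout << "RAW candidate on TIDx: " << raw_candidates.test(0) << "\n";  // 1
  std::cout << "RAW candidate on BIDx: " << raw_candidates.test(3) << "\n";  // 0
}

In SyncMap::build the surviving bits then hit the producer_redundant_types.get(parallel_type) check and are recorded in raw_dims.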
205 changes: 205 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
@@ -285,6 +285,184 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
}
}

namespace {

//! A simple backward data flow pass:
//! This pass propagates information backward to annotate "redundant use
//! chains".
//! The reason this is needed is that, for example, if we have a chain
//! of register-to-register ops that begins with a redundant shared mem write
//! and ends with an op that non-redundantly uses the result, we'd need to
//! insert a sync at the beginning of the register-to-register chain.
//!
//! The same mechanism also applies in the case of a register/sharedmem chain
//! that starts and ends with global memory read/write.
//!
//! The propagation rule is summarized as follows:
//!
//! Shared TV val:
//! Reset all block redundant info to its own redundant write info
//! Backpropagate grid redundant info
//! Global TV val:
//! Reset all redundant info to its own redundant write info
//! Local TV val:
//! Backpropagate all redundant info
//! Exprs:
//! Propagate redundant info backwards from outputs to inputs:
//! For each parallel type,
//! The parallel type is redundantly used in the expr input
//! only if all of the outputs redundantly use the same type.
class RedundantUseAnalysis : BackwardVisitor {
public:
RedundantUseAnalysis(Fusion* fusion, const ThreadPredicateMap& pred_map)
: fusion_(fusion), pred_map_(pred_map) {
traverseFrom(fusion, fusion->terminatingMathVals());
}

//! Returns a bit map signifying the parallel dimensions
//! on which the given tv is redundantly used. On these
//! dimensions, not all threads/blocks are required to
//! hold a valid value for their dependent computations.
ParallelTypeBitmap getRedundantUseBitMap(const TensorView* tv) {
// Since all tv's consumers are visited at this point, we
// can aggregate the final redundant use info for this tv.
if (fusion_->unordered_uses(tv).empty()) {
// Base case: an unused val is also not redundantly used
return ParallelTypeBitmap();
} else {
// Aggregate redundant use as a conjunction of all
// consumers' redundant use info propagated
// backward from their consumer chains.
ParallelTypeBitmap redundant_use;
redundant_use.setAllBID();
redundant_use.setAllTID();
for (auto expr : fusion_->unordered_uses(tv)) {
redundant_use &= redundant_expr_use_map_.at(expr);
}

return redundant_use;
}
}

private:
using BackwardVisitor::handle;

void handle(TensorView* tv) final {
auto redundant_tv_map = pred_map_.getPredicateInfo(tv).redundant_types;

// Set up the info to propagate backward for the producer tvs and
// expressions.
ParallelTypeBitmap& redundant_consumer_map =
redundant_consumer_parallel_type_map_[tv];

// Initialize the use map to the redundant pred result
redundant_consumer_map = redundant_tv_map;

if (tv->getMemoryType() == MemoryType::Shared) {
backPropagateRedundantUse(
redundant_consumer_map,
tv,
false, // do not propagate TID redundant use for shared tv
true // propagate BID redundant use
);

} else if (tv->getMemoryType() == MemoryType::Local) {
backPropagateRedundantUse(
redundant_consumer_map,
tv,
true, // propagate TID redundant use
true // propagate BID redundant use
);
}
}

void backPropagateRedundantUse(
ParallelTypeBitmap& use_map,
TensorView* tv,
bool propagate_tid,
bool propagate_bid) {
// Clear the propagated part of the original result
if (propagate_bid) {
use_map.setAllBID();
}
if (propagate_tid) {
use_map.setAllTID();
}

for (auto expr : fusion_->unordered_uses(tv)) {
// Assuming all consumer expressions have been
// visited at this point since we are traversing
// backward.
auto expr_use_map = redundant_expr_use_map_.at(expr);
// Clear the part of expression use map that does not
// need to be propagated.
if (!propagate_bid) {
expr_use_map.setAllBID();
}
if (!propagate_tid) {
expr_use_map.setAllTID();
}

// Accumulate expression redundant usage
// This implements the `only if all` part in
// the discussion above.
use_map &= expr_use_map;
}
}

void handle(Expr* expr) final {
if (ir_utils::isTvOp(expr)) {
// Initialize redundant info for current expr
c10::optional<ParallelTypeBitmap> maybe_expr_pred_map;

for (auto consumer_tv :
ir_utils::filterByType<TensorView>(expr->outputs())) {
auto tv_redundant_bitmap =
redundant_consumer_parallel_type_map_.at(consumer_tv);

if (maybe_expr_pred_map.has_value()) {
// Accumulate redundant info of this tv output.
maybe_expr_pred_map.value() &= tv_redundant_bitmap;
} else {
// Copy the tv's redundant info as the first valid case.
maybe_expr_pred_map = tv_redundant_bitmap;
}
}

TORCH_INTERNAL_ASSERT(
maybe_expr_pred_map.has_value(), "TV op not having a tv output");
redundant_expr_use_map_[expr] = maybe_expr_pred_map.value();
}
}

private:
// Populated redundant use information on the used tvs.
// This map records whether the given tv does not require
// valid data from its producer on any parallel dimensions.
// For example:
//   T1_local = T0_shared[...]
//   if(tid.x == 0)
//     T2_shared[...] = T1_local[...]
// Then tidx would be a redundant consumer parallel type
// for T1: as T1 is a local tensor, only threads satisfying
// tidx == 0 need to hold valid data.
// In this case, not all threads need to read correct data
// from T0_shared, which helps remove some syncs.
std::unordered_map<const TensorView*, ParallelTypeBitmap>
redundant_consumer_parallel_type_map_;

// Populated redundant use information on the used tv expressions.
std::unordered_map<const Expr*, ParallelTypeBitmap> redundant_expr_use_map_;

// Shortcut to the owning fusion of this analysis.
Fusion* fusion_ = nullptr;

// Shortcut to the active pred map analysis this pass runs as part of.
const ThreadPredicateMap& pred_map_;
};

} // namespace

void ThreadPredicateMap::build(Fusion* fusion) {
FUSER_PERF_SCOPE("GpuLower::Lower::ThreadPredicateMap");

@@ -298,6 +476,15 @@ void ThreadPredicateMap::build(Fusion* fusion) {
updateBitSet(expr);
}
updated_tvs_.clear();
populateRedundantUseMap(fusion);
}

void ThreadPredicateMap::populateRedundantUseMap(Fusion* fusion) {
RedundantUseAnalysis redundant_use(fusion, *this);
for (auto& it : thread_predicates_) {
it.second.redundant_use_types =
redundant_use.getRedundantUseBitMap(it.first);
}
}

ThreadPredicateMap::const_iterator ThreadPredicateMap::find(
@@ -399,6 +586,23 @@ ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains(
return parallel_broadcast & at(tv).limited_types;
}

ParallelTypeBitmap ThreadPredicateMap::getRedundantConsumerType(
Expr* expr) const {
c10::optional<ParallelTypeBitmap> result;
for (auto out_tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
auto out_tv_redundant_map = getPredicateInfo(out_tv).redundant_use_types;
if (!result.has_value()) {
result = out_tv_redundant_map;
} else {
result.value() &= out_tv_redundant_map;
}
}

TORCH_INTERNAL_ASSERT(
result.has_value(), "ThreadPredicateMap : TV op assumed");
return result.value();
}

void ThreadPredicateMap::markAsUpdated(const TensorView* tv) {
updated_tvs_.insert(tv);
}
@@ -410,6 +614,7 @@ void ThreadPredicateMap::print() const {
std::cout << "T" << kv.first->name();
std::cout << " {" << kv.second.limited_types.toString() << "}\n";
std::cout << "{" << kv.second.redundant_types.toString() << "}\n";
std::cout << "{" << kv.second.redundant_use_types.toString() << "}\n";
}
std::cout << "--------------------------------\n\n";
}
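
The analysis added above reduces to two rules: an expression's redundant-use bitmap is the conjunction of its outputs' bitmaps, and a tv's bitmap starts from its own redundant-write bits with the propagatable part, which depends on the memory type, overwritten by what its consumers report. Below is a deliberately simplified, self-contained model of the per-tv step, assuming the tv has at least one consumer expression; Bitmap, MemoryType, and the bit layout are illustrative stand-ins rather than the real nvfuser types.

// Simplified model of the backward propagation rule in RedundantUseAnalysis.
#include <bitset>
#include <iostream>
#include <vector>

using Bitmap = std::bitset<6>;
// Illustrative bit layout: bits 0-2 = TIDx/y/z, bits 3-5 = BIDx/y/z.
const Bitmap kAllTID("000111");
const Bitmap kAllBID("111000");

enum class MemoryType { Local, Shared, Global };

// Conjunction over consumer expressions: a parallel type stays redundant
// only if every consumer expression reports it as redundant.
Bitmap fromConsumers(const std::vector<Bitmap>& consumer_maps) {
  Bitmap use = kAllTID | kAllBID;
  for (const Bitmap& m : consumer_maps) {
    use &= m;
  }
  return use;
}

// Per-tv rule: start from the tv's own redundant-write bits, then overwrite
// the propagatable part (which depends on memory type) with the consumer info.
Bitmap redundantUse(
    MemoryType mem,
    Bitmap own_redundant_write,
    const std::vector<Bitmap>& consumer_maps) {
  Bitmap propagated = fromConsumers(consumer_maps);
  switch (mem) {
    case MemoryType::Global:
      return own_redundant_write;               // keep own info only
    case MemoryType::Shared:
      return (own_redundant_write & kAllTID) |  // TID part: own info
          (propagated & kAllBID);               // BID part: propagated
    case MemoryType::Local:
      return propagated;                        // propagate everything
  }
  return own_redundant_write;  // unreachable; silences compiler warnings
}

int main() {
  // A local tv redundantly written on TIDx whose single consumer is also
  // redundant on TIDx: the redundancy survives back-propagation.
  std::cout << redundantUse(
                   MemoryType::Local, Bitmap("000001"), {Bitmap("000001")})
            << "\n";  // 000001
  // Same tv, but the consumer actually needs all threads: nothing survives.
  std::cout << redundantUse(
                   MemoryType::Local, Bitmap("000001"), {Bitmap("000000")})
            << "\n";  // 000000
}

In the actual pass, handle(TensorView*) and backPropagateRedundantUse cover the memory-type cases, handle(Expr*) performs the output conjunction, and getRedundantUseBitMap treats an unused tv as not redundantly used at all.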
24 changes: 23 additions & 1 deletion torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
@@ -48,9 +48,24 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
ParallelTypeBitmap limited_types;
// Parallel types where only one thread/block is enough.
ParallelTypeBitmap redundant_types;
// Tracking the use chains of redundant writes:
// [Redundant use chain]
// A parallel type is a `redundant_consumer_type` only
// if all of its propagation use chains terminate with
// a redundant write of this type.
// A propagation use chain is currently either a reg-to-reg
// chain for a shared mem tv, or a reg/smem-to-reg/smem chain
// for a global tv.
// This is complementary information to `redundant_types`.
// If a tensor view is redundantly written but not redundantly
// used by all consumers (see FusionRedundantPredSync3),
// a RAW sync will need to be inserted before reading
// this redundantly written tensor.
ParallelTypeBitmap redundant_use_types;
bool operator==(const PredicateInfo& other) const {
return limited_types == other.limited_types &&
redundant_types == other.redundant_types;
redundant_types == other.redundant_types &&
redundant_use_types == other.redundant_use_types;
}
};

@@ -92,6 +107,9 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
static Bool* getPredicateFromPredicateInfo(
const ThreadPredicateMap::PredicateInfo& pred_info);

//! Get the redundant use types of the given expr, see [Redundant use chain]
ParallelTypeBitmap getRedundantConsumerType(Expr* expr) const;

private:
// Update the thread_predicates bitset based on provided Expr
void updateBitSet(const Expr*);
@@ -111,6 +129,10 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
//! Update a mapping
bool update(const TensorView* tv, const PredicateInfo& pred_and_src);

//! Backward-populate redundant use chain info once the redundant
//! parallel writes have been identified.
void populateRedundantUseMap(Fusion* fusion);

private:
MapType thread_predicates_;
//! Keep track of updated tensors that need predicates to be computed