Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 158 additions & 1 deletion torch/csrc/jit/codegen/cuda/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@ bool IterDomainGraph::exprsMap(
return true;
}

// Given first and second Exprs "match":
//   Expr type matches
//   IterDomains in the inputs and outputs exact match (including argument
//     positions)
//   Parameters like Split's factor "match" (exact match on integers could be
//     better, as today it will just check it's the same symbol or evaluated to
//     the same constant. However, we know all the extents of all the
//     IterDomains that exact map with each other are the same value.)
void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
if (first == nullptr || second == nullptr) {
return;
Expand Down Expand Up @@ -194,7 +202,37 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {

namespace {

// Returns a pair of mapped IDs
// Returns the first pair of ids in `ids` detected to match each other on the
// permissive map of the ID graph. TODO: what this is really looking for is
// whether there's any overlap between the iter domains in the provided set.
//
// i.e. if we have:
// tv0 = arange(6).view({3, 2})
// tv1 = tv0[3, 2].t()
// tv2 = tv0[3, 2].view({2, 3})
// tv3 = tv1 + tv2
//
// Then we can see this overlap in the tv3 expression as:
//
// tv0 = { {0, 1, 2},
// {3, 4, 5} }
//
// tv1 = { {0, 3},
// {1, 4},
// {2, 5} }
//
// tv2 = { {0, 1},
// {2, 3},
// {4, 5} }
//
// The elements in tv1 {3, 1, 4, 2}, map respectively to the elements in tv2 {1,
// 2, 3, 4}. The reason this is so important is it means that generating tv3 is
// no longer a trivially parallelizable problem (if we include the dag all the
// way to tv0). So tv0's axes cannot be inlined across both the tv0 and tv1
// path. This breaks some assumptions we have today in schedulers that will
// assume tv2 can be trivially inlined/parallelized. Instead we'd need to take
// into consideration the effective communication going on here, so that we pull
// multiple values of tv0 to compute tv3.
c10::optional<std::pair<IterDomain*, IterDomain*>> detectMappablePair(
const std::vector<IterDomain*>& ids,
const IterDomainGraph& id_graph) {
Expand Down Expand Up @@ -637,6 +675,7 @@ ComputeAtMap::ComputeAtMap(Fusion* fusion)
void ComputeAtMap::build(Fusion* fusion) {
trivial_reduction_info_.build(fusion);
buildConcreteIds();
buildUniqueExactExprMaps();
}

void ComputeAtMap::validateAndPropagatePType() {
Expand Down Expand Up @@ -1052,6 +1091,124 @@ void ComputeAtMap::buildConcreteIds() {
}
}

// Returns true when expr_1 and expr_2 are the same kind of expression and
// every IterDomain input/output of expr_1 is EXACT-mapped to the value at the
// same position in expr_2. For Swizzle2D the swizzle type and mode must also
// agree. Values of expr_1 that are not IterDomains are skipped.
bool ComputeAtMap::areExactExprs(Expr* expr_1, Expr* expr_2) {
  // Different kinds of expressions can never be the same transformation.
  if (expr_1->getExprType() != expr_2->getExprType()) {
    return false;
  }

  // Swizzles carry extra parameters that have to agree as well.
  if (expr_1->isA<Swizzle2D>()) {
    const auto lhs_swizzle = expr_1->as<Swizzle2D>();
    const auto rhs_swizzle = expr_2->as<Swizzle2D>();
    const bool same_swizzle =
        lhs_swizzle->swizzleType() == rhs_swizzle->swizzleType() &&
        lhs_swizzle->swizzleMode() == rhs_swizzle->swizzleMode();
    if (!same_swizzle) {
      return false;
    }
  }

  TORCH_INTERNAL_ASSERT(
      expr_1->inputs().size() == expr_2->inputs().size() &&
          expr_1->outputs().size() == expr_2->outputs().size(),
      "Expr traversal doesn't support variable number of inputs and outputs.");

  // NOTE(review): only expr_1's operand is tested with isA<IterDomain>() before
  // the operand of expr_2 at the same position is cast; presumably expressions
  // of the same type always have the same operand kinds -- confirm.

  // Inputs must exact map position by position.
  const auto& inputs_1 = expr_1->inputs();
  const auto& inputs_2 = expr_2->inputs();
  for (const auto pos : c10::irange(inputs_1.size())) {
    if (!inputs_1[pos]->isA<IterDomain>()) {
      continue;
    }
    const bool exact_mapped = areMapped(
        inputs_1[pos]->as<IterDomain>(),
        inputs_2[pos]->as<IterDomain>(),
        IdMappingMode::EXACT);
    if (!exact_mapped) {
      return false;
    }
  }

  // Outputs must exact map position by position.
  const auto& outputs_1 = expr_1->outputs();
  const auto& outputs_2 = expr_2->outputs();
  for (const auto pos : c10::irange(outputs_1.size())) {
    if (!outputs_1[pos]->isA<IterDomain>()) {
      continue;
    }
    const bool exact_mapped = areMapped(
        outputs_1[pos]->as<IterDomain>(),
        outputs_2[pos]->as<IterDomain>(),
        IdMappingMode::EXACT);
    if (!exact_mapped) {
      return false;
    }
  }

  // Same expression type with exact mapped IterDomain inputs and outputs.
  return true;
}

void ComputeAtMap::buildUniqueExactExprMaps() {
// Start by building definitions
for (const auto& disjoint_set_shared_ptr :
id_graph_.exactNodes().disjointSets()) {
std::vector<Expr*> definitions;

// N^2 in number of unique transformations, this might be better to do
// when generating the map.
IterDomain* concrete_id = nullptr;
for (auto id : disjoint_set_shared_ptr->vector()) {
if (concrete_id == nullptr) {
concrete_id = getConcreteMappedID(id, IdMappingMode::EXACT);
}

if (id->definition() != nullptr) {
bool match = false;
for (auto recorded_def : definitions) {
if (areExactExprs(id->definition(), recorded_def)) {
match = true;
break;
}
}
if (!match) {
definitions.push_back(id->definition());
}
}
}
unique_exact_definitions_[disjoint_set_shared_ptr] = definitions;
}

// Use definitions to build uses
for (const auto& disjoint_set_shared_ptr :
id_graph_.exactNodes().disjointSets()) {
auto definition_it =
unique_exact_definitions_.find(disjoint_set_shared_ptr);

if (definition_it == unique_exact_definitions_.end()) {
continue;
}

const auto& definitions = definition_it->second;

for (auto definition : definitions) {
auto inp_ids = ir_utils::filterByType<IterDomain>(definition->inputs());
for (auto inp : inp_ids) {
auto inp_disjoint_set_shared_ptr =
disjointSetOf(inp, IdMappingMode::EXACT);
// Initialize uses entry
if (unique_exact_uses_.find(inp_disjoint_set_shared_ptr) ==
unique_exact_uses_.end()) {
unique_exact_uses_[inp_disjoint_set_shared_ptr] = {};
}

auto& uses = unique_exact_uses_.at(inp_disjoint_set_shared_ptr);

bool already_added = false;
for (auto other_use : uses) {
if (areExactExprs(definition, other_use)) {
already_added = true;
break;
}
}
if (already_added) {
continue;
}

if (!already_added) {
uses.push_back(definition);
}
}
}
}
}

IterDomain* ComputeAtMap::getConcreteMappedID(
IterDomain* id,
IdMappingMode mode) const {
Expand Down
58 changes: 54 additions & 4 deletions torch/csrc/jit/codegen/cuda/compute_at_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,31 @@ class TORCH_CUDA_CU_API ComputeAtMap {
//! guaranteed to return iter domains in the same disjoint set.
IterDomain* getConcreteMappedID(IterDomain* id, IdMappingMode mode) const;

//! Returns, by value, the deduplicated definitions recorded for the exact
//! disjoint set that contains 'id'. Expressions that are the same exact
//! transformation appear only once. An empty vector is returned when nothing
//! was recorded for that set.
std::vector<Expr*> uniqueExactDefinitions(IterDomain* id) const {
  const auto defs_it =
      unique_exact_definitions_.find(disjointSetOf(id, IdMappingMode::EXACT));
  return defs_it == unique_exact_definitions_.end() ? std::vector<Expr*>{}
                                                    : defs_it->second;
}

//! Returns, by value, the deduplicated expressions recorded as *uses* of the
//! exact disjoint set that contains 'id'. Expressions that are the same exact
//! transformation appear only once. An empty vector is returned when nothing
//! was recorded for that set.
std::vector<Expr*> uniqueExactUses(IterDomain* id) const {
  const auto uses_it =
      unique_exact_uses_.find(disjointSetOf(id, IdMappingMode::EXACT));
  return uses_it == unique_exact_uses_.end() ? std::vector<Expr*>{}
                                             : uses_it->second;
}

// Prints mapping information, forwards to an internal IterDomainGraph
std::string toString() const;

Expand Down Expand Up @@ -211,6 +236,16 @@ class TORCH_CUDA_CU_API ComputeAtMap {
DoubleBufferLoopStage double_buffer_loop_stage =
DoubleBufferLoopStage::NotApplicable) const;

// Returns if expr_1 and expr_2 have exact mapped IterDomains in
// inputs/outputs (order matters) and if the expressions have matching
// parameters.
bool areExactExprs(Expr* expr_1, Expr* expr_2);

// Produce the disjoint set containing provided id with mapping mode.
const std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>& disjointSetOf(
IterDomain* id,
IdMappingMode mode) const;

private:
// Build id_graph_
void build(Fusion* fusion);
Expand All @@ -220,10 +255,8 @@ class TORCH_CUDA_CU_API ComputeAtMap {
IterDomain* computeConcreteId(IterDomain* id, IdMappingMode mode);
void buildConcreteIds();

// Produce the disjoint set containing provided id with mapping mode.
const std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>& disjointSetOf(
IterDomain* id,
IdMappingMode mode) const;
// Relies on concrete_id_cache_, buildConcreteIds() must be run before this.
void buildUniqueExactExprMaps();

// Should be built once and never modified again.
IterDomainGraph id_graph_;
Expand All @@ -239,6 +272,23 @@ class TORCH_CUDA_CU_API ComputeAtMap {
IterDomain*>
concrete_id_cache_;

// Unique expressions operating on exact disjoint sets. For each IterDomain in
// each exact disjoint set, its definition is logged in the std::vector<Expr*>.
// If another expression is already in the set where inputs and outputs
// exactly match with the expression to add, along with the other parameters of
// the transformation (like split's factor, or swizzle types), then the
// expression will not be added as it would be a "duplicate" transformation.
std::unordered_map<
std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>,
std::vector<Expr*>>
unique_exact_definitions_;

// Same as unique_exact_definitions_ but for uses instead of definitions
std::unordered_map<
std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>,
std::vector<Expr*>>
unique_exact_uses_;

//! Allocated Loop index variable through the CA map.
//! only valid for disjoint sets on the loop ca map.
std::unordered_map<const VectorOfUniqueEntries<IterDomain*>*, Val*>
Expand Down
44 changes: 41 additions & 3 deletions torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,18 @@ std::shared_ptr<ReductionParams> persistentHeuristic(
const int64_t max_persistent_buffer_size,
size_t vectorize_factor,
bool project_persistent_buffers) {
std::cout << "Within heuristics" << std::endl;
std::cout << "\n"
<< total_reduction_numel << "\n" // 945
<< total_iteration_numel << "\n" // 42
<< inner_most_dimension_numel << "\n" // 63
<< fastest_dim_reduction << "\n" // 1
<< n_tensor_inputs << "\n" // 2
<< max_input_dtype_size << "\n" // 4
<< max_persistent_buffer_size << "\n" // 0
<< vectorize_factor << "\n" // 1
<< project_persistent_buffers << std::endl; // 1

std::shared_ptr<ReductionParams> rparams;
if (fastest_dim_reduction) {
rparams = innerPersistentHeuristic(
Expand All @@ -788,13 +800,15 @@ std::shared_ptr<ReductionParams> persistentHeuristic(
vectorize_factor);
}
rparams->project_persistent_buffers = project_persistent_buffers;
std::cout<<"Finished within heuristics"<<std::endl;
return rparams;
}

TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache) {
std::cout<<"???"<<std::endl;
FUSER_PERF_SCOPE("getPersistentHeuristics");

FusionGuard fg(fusion);
Expand Down Expand Up @@ -847,12 +861,20 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
scheduler_utils::getProperties(fusion, runtime_info, first_red_tv);

// Grab persistent buffer sizes
fusion->printMath();
auto persistent_buffer_size_info = scheduler_utils::persistentBufferSize(
fusion, runtime_info, persistent_buffer_info, data_cache);
// If projected persistent buffers are smaller, they will be used.
auto max_persistent_size = std::min(
persistent_buffer_size_info.persistent_buffer_size,
persistent_buffer_size_info.projected_persistent_buffer_size);
// TODO: projected buffers coming up as 0 in NVFuserTest.FusionViewMagicSchedule7_CUDA
auto max_persistent_size = ir_utils::getViewOps(fusion).size() > 0
? persistent_buffer_size_info.persistent_buffer_size
: std::min(
persistent_buffer_size_info.persistent_buffer_size,
persistent_buffer_size_info.projected_persistent_buffer_size);

std::cout << persistent_buffer_size_info.persistent_buffer_size << " :: "
<< persistent_buffer_size_info.projected_persistent_buffer_size
<< std::endl;

// Figure out if we want to project persistent buffers to the inputs, for
// example if we have an input tensor t0 that's fp16:
Expand Down Expand Up @@ -908,6 +930,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
vectorize_factor = 1;
}

std::cout<<"A"<<std::endl;
// Try expanding vectorization to contig merged domains
vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains(
fusion,
Expand All @@ -917,6 +940,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
(int)(first_red_tv->nDims() - properties.inner_most_dimension_ndims),
vectorize_factor);

std::cout<<"B"<<std::endl;
// Base max dtype and n_tensor_inputs on tensors that are vectorizable (i.e.
// share inner dimension with data pattern we're looking at).
size_t max_dtype_size = 1;
Expand Down Expand Up @@ -996,6 +1020,18 @@ TORCH_CUDA_CU_API void schedulePersistentKernel(
// changes registry needs to change.
auto reduction_tv = reduction_tvs[0];

if (ir_utils::getViewOps(fusion).size() > 0) {
ComputeAtMap ca_map(fusion);
// Propagate view transforms through the graph, especially the reference.
scheduler_utils::propagateViewTransforms(fusion, ca_map);

// Reorder reference_tv after propagating the view operation. This will
// reorder for better merging.
reduction_tv->reorder(
scheduler_utils::domainReorderAsRfactorMap(reduction_tv));
fusion->printMath();
}

auto dim_analysis = scheduler_utils::canonicalDimReduction(
fusion, reduction_tv, rparams.fastest_dim && rparams.schedule_3D);
bool has_iter_axis = dim_analysis.first;
Expand Down Expand Up @@ -1028,6 +1064,8 @@ TORCH_CUDA_CU_API void schedulePersistentKernel(
reduction_tvs,
cached_inputs,
cached_outputs);

fusion->printMath();
}

} // namespace cuda
Expand Down
Loading