Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IterDomain resize for pad, cat, slice #3

Merged
merged 4 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ if(BUILD_TEST)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_outer_reduction.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_loop_rotation.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_shift.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_resize.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_tensorcore.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_matmul_sass.cpp)
list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_view.cpp)
Expand Down
35 changes: 35 additions & 0 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2774,6 +2774,41 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
}

// Emits the kernel code for a CatOp as an if / else-if / else chain.
// Branch i is guarded by cat->getPred(i) and assigns input i to the output;
// the last input needs no predicate and becomes the plain `else`.
//
// For inputs in_0 .. in_3 the generated code has the shape:
//   if (pred_0) {
//     out = in_0;
//   } else if (pred_1) {
//     out = in_1;
//   } else if (pred_2) {
//     out = in_2;
//   } else {
//     out = in_3;
//   }
void handle(const CatOp* cat) final {
  const auto out = gen(cat->output(0));
  const auto num_inputs = cat->inputs().size();

  // A lone input has no predicate to branch on. Running it through the
  // generic chain would open with "} else {" and emit unbalanced braces, so
  // write a plain assignment instead.
  if (num_inputs == 1) {
    indent() << out << " = " << gen(cat->input(0)->as<kir::TensorIndex>())
             << ";\n";
    return;
  }

  for (const auto i : c10::irange(num_inputs)) {
    // Generate the input expression once and reuse it for the assignment
    const auto inp_str = gen(cat->input(i)->as<kir::TensorIndex>());

    if (i < num_inputs - 1) {
      if (i == 0) {
        indent() << "if (";
      } else {
        indent() << "} else if (";
      }
      code_ << gen(cat->getPred(i)) << ") {\n";
    } else {
      // last case doesn't need to be predicated
      indent() << "} else {\n";
    }

    indent() << kTab << out << " = " << inp_str << ";\n";
  }

  // Close the final branch of the chain
  indent() << "}\n";
}

private:
std::stringstream code_;
const kir::Kernel* kernel_;
Expand Down
78 changes: 70 additions & 8 deletions csrc/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ bool IterDomainGraph::exprsMap(
}

TORCH_INTERNAL_ASSERT(
first->isA<Merge>() || first->isA<Split>(),
"Merge and split are the only expressions supported through rfactor operations in compute at map, but found:\n",
first->isA<Merge>() || first->isA<Split>() || first->isA<Resize>(),
"Merge, split and resize are the only expressions supported through rfactor operations in compute at map, but found:\n",
first->toString());

auto first_ids = ir_utils::filterByType<IterDomain>(
Expand Down Expand Up @@ -176,6 +176,15 @@ bool IterDomainGraph::exprsMap(
}
}

if (first->isA<Resize>()) {
auto first_resize = first->as<Resize>();
auto second_resize = second->as<Resize>();
if (!first_resize->leftExpand()->sameAs(second_resize->leftExpand()) ||
!first_resize->rightExpand()->sameAs(second_resize->rightExpand())) {
return false;
}
}

return true;
}

Expand Down Expand Up @@ -211,6 +220,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
for (auto out_i : c10::irange(first_ids.size())) {
exact_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
permissive_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
permissive_resize_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]);
}
}

Expand Down Expand Up @@ -407,6 +417,7 @@ void IterDomainGraph::build(Fusion* fusion) {
auto id0 = *disjoint_set->begin();
for (auto id1 : disjoint_set->vector()) {
permissive_nodes_.mapEntries(id0, id1);
permissive_resize_nodes_.mapEntries(id0, id1);
exact_nodes_.mapEntries(id0, id1);
sibling_sets_.mapEntries(id0, id1);
}
Expand All @@ -430,8 +441,22 @@ void IterDomainGraph::build(Fusion* fusion) {
// Look for matching ID transformations in producer and consumer, replay
// producer as consumer. We use the symmetric API of BestEffortReplay so
// that both broadcast and squeeze are handled correctly.
//
// Note on the boolean flags: swizzles are skipped in both
// producer and consumer but resizes are not.
const auto permissive_disjoint_sets =
BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map)
BestEffortReplay::replayPasC(
p_tv, c_tv, -1, pairwise_map, true, true, false)
.getIterDomainEquivalence();

// Permissive-Resize map allows mappings of resize inputs and
// outputs
//
// Note on the boolean flags: swizzles and resizes are skipped
// in the permissive-resize map
const auto permissive_resize_disjoint_sets =
BestEffortReplay::replayPasC(
p_tv, c_tv, -1, pairwise_map, true, true, true)
.getIterDomainEquivalence();

// For exact mappings do not map any broadcast dimensions to
Expand Down Expand Up @@ -483,16 +508,12 @@ void IterDomainGraph::build(Fusion* fusion) {
for (auto j : c10::irange(i + 1, vec.size())) {
auto id2 = vec[j];
if (p_ids.count(id1) && c_ids.count(id2)) {
consumers_.at(id1).pushBack(id2);
producers_.at(id2).pushBack(id1);
if (idIsAComputeAtLeafDomain(id1, p_tv, c_tv) &&
idIsALeafDomain(id2, c_tv)) {
loop_nodes_.mapEntries(id1, id2);
}
}
if (c_ids.count(id1) && p_ids.count(id2)) {
producers_.at(id1).pushBack(id2);
consumers_.at(id2).pushBack(id1);
if (idIsAComputeAtLeafDomain(id2, p_tv, c_tv) &&
idIsALeafDomain(id1, c_tv)) {
loop_nodes_.mapEntries(id1, id2);
Expand All @@ -501,6 +522,31 @@ void IterDomainGraph::build(Fusion* fusion) {
}
}
}

// Mostly the same as the above for the permissive map but
// nothing to do for the loop map.
// The producer and consumer maps are based on the most
// permissive mappings, so they are set using the
// permissive-resize mappings.
for (auto& dset : permissive_resize_disjoint_sets.disjointSets()) {
auto& vec = dset->vector();
for (auto i : c10::irange(vec.size())) {
auto id1 = vec[i];
permissive_resize_nodes_.mapEntries(id1, vec[0]);
mapMaybeSwizzleOp(permissive_resize_nodes_, id1);
for (auto j : c10::irange(i + 1, vec.size())) {
auto id2 = vec[j];
if (p_ids.count(id1) && c_ids.count(id2)) {
consumers_.at(id1).pushBack(id2);
producers_.at(id2).pushBack(id1);
}
if (c_ids.count(id1) && p_ids.count(id2)) {
producers_.at(id1).pushBack(id2);
consumers_.at(id2).pushBack(id1);
}
}
}
}
}
}
}
Expand Down Expand Up @@ -561,7 +607,7 @@ void IterDomainGraph::build(Fusion* fusion) {
for (auto expr : exprs) {
auto rfactor_inp_ids = ir_utils::filterByType<IterDomain>(expr->inputs());
TORCH_INTERNAL_ASSERT(
expr->isA<Split>() || expr->isA<Merge>(),
expr->isA<Split>() || expr->isA<Merge>() || expr->isA<Resize>(),
"Wasn't expecting the expression type of:\n",
expr->toString(),
"\nto be an expression defined in an rfactor transformation.");
Expand Down Expand Up @@ -688,6 +734,7 @@ void IterDomainGraph::initializeId(
bool is_rfactor_id,
bool is_leaf_id) {
permissive_nodes_.initializeSet(id);
permissive_resize_nodes_.initializeSet(id);
exact_nodes_.initializeSet(id);
if (is_leaf_id) {
loop_nodes_.initializeSet(id);
Expand Down Expand Up @@ -1127,6 +1174,17 @@ void ComputeAtMap::buildConcreteIds() {
auto concrete_id = computeConcreteId(first_id, IdMappingMode::LOOP);
concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
}

for (const auto& disjoint_set_shared_ptr :
id_graph_.permissiveResizeNodes().disjointSets()) {
TORCH_INTERNAL_ASSERT(
disjoint_set_shared_ptr->vector().size(),
"Cannot compute concrete id of empty set.");
auto first_id = disjoint_set_shared_ptr->vector().front();
auto concrete_id =
computeConcreteId(first_id, IdMappingMode::PERMISSIVE_RESIZE);
concrete_id_cache_[disjoint_set_shared_ptr] = concrete_id;
}
}

bool ComputeAtMap::areExactExprs(Expr* expr_1, Expr* expr_2) {
Expand Down Expand Up @@ -1349,6 +1407,8 @@ std::string ComputeAtMap::toString() const {
ss << "Loop map:\n" << idGraphNodesToString(*this, IdMappingMode::LOOP);
ss << "Permissive map:\n"
<< idGraphNodesToString(*this, IdMappingMode::PERMISSIVE);
ss << "Permissive-Resize map:\n"
<< idGraphNodesToString(*this, IdMappingMode::PERMISSIVE_RESIZE);
ss << "Consumer maps:\n";
for (auto key : getSortedKeys(id_graph_.consumers(), Statement::lessThan)) {
auto consumers = id_graph_.consumers().at(key);
Expand Down Expand Up @@ -1408,6 +1468,8 @@ const DisjointSets<IterDomain*>& ComputeAtMap::getIdSets(
return id_graph_.loopNodes();
case IdMappingMode::PERMISSIVE:
return id_graph_.permissiveNodes();
case IdMappingMode::PERMISSIVE_RESIZE:
return id_graph_.permissiveResizeNodes();
}
TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided.");
}
Expand Down
12 changes: 11 additions & 1 deletion csrc/compute_at_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ namespace nvfuser {
// Map all iteration domains
// Always contain root mappings (otherwise they could have been forwarded in
// broadcast)
// IdMappingMode::PERMISSIVE_RESIZE
// Include everything in PERMISSIVE. Map also domains that are
// inputs and outputs of resize ops. Used for, e.g., propagating
// parallel types across those domains.
// IdMappingMode::EXACT
// Don't map any broadcast axes to non-broadcast axes
// Do not forward through any broadcast IDs
Expand All @@ -79,6 +83,9 @@ class TORCH_CUDA_CU_API IterDomainGraph {
const DisjointSets<IterDomain*>& loopNodes() const {
return loop_nodes_;
}
// Disjoint sets backing IdMappingMode::PERMISSIVE_RESIZE: everything mapped
// in the permissive map, plus domains that are inputs and outputs of resize
// ops.
const DisjointSets<IterDomain*>& permissiveResizeNodes() const {
return permissive_resize_nodes_;
}

// Consumers and producers is not symmetric like the other sets
const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
Expand Down Expand Up @@ -132,8 +139,11 @@ class TORCH_CUDA_CU_API IterDomainGraph {
DisjointSets<IterDomain*> exact_nodes_;
DisjointSets<IterDomain*> almost_exact_nodes_;
DisjointSets<IterDomain*> loop_nodes_;
DisjointSets<IterDomain*> permissive_resize_nodes_;

// Consumers and producers is not symmetric like the other sets
// Consumers and producers is not symmetric like the other sets.
// Mapping is based on the most permissive map, i.e., the
// permissive-resize map.
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
consumers_;
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
Expand Down
61 changes: 61 additions & 0 deletions csrc/contiguity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,48 @@ void OrderedIdInformation::handle(Swizzle2D* swizzle) {
}
}

// Replays a Resize op on the tracked ordering information. The resize output
// takes over its input's slot in active_ids_ and inherits the input's
// consistently-ordered flag, its exclusive-root-consumption flag and its set
// of root IDs. Does nothing when the input is not currently active.
void OrderedIdInformation::handle(Resize* resize) {
  auto* const in_id = resize->in();
  auto* const out_id = resize->out();

  // Nothing to replay unless the resize input is one of the active IDs
  const auto active_it =
      std::find(active_ids_.begin(), active_ids_.end(), in_id);
  if (active_it == active_ids_.end()) {
    return;
  }

  // Remember whether the input was consistently ordered
  const bool in_is_ordered =
      consistently_ordered_ids_.find(in_id) != consistently_ordered_ids_.end();

  // Root IDs of the (single) resize input; an active ID must have an entry
  const auto root_ids_it = id_to_root_ids_.find(in_id);
  TORCH_INTERNAL_ASSERT(
      root_ids_it != id_to_root_ids_.end(),
      "Error replaying transforms in contiguous ID checker.");
  const auto& root_ids_of_in = root_ids_it->second;

  // The output ID replaces the input ID in place in the active list
  *active_it = out_id;

  // Not completely certain, but propagating these properties should be
  // fine
  if (in_is_ordered) {
    consistently_ordered_ids_.emplace(out_id);
  }

  if (exclusivelyConsumesRoots(in_id)) {
    exclusively_consumes_roots_.emplace(out_id);
  }

  // The output depends on exactly the same root IDs as the input
  id_to_root_ids_[out_id] = root_ids_of_in;
}

NonDivisibleSplitDependencies::NonDivisibleSplitDependencies(
// TODO: Revisit reduction rfactor axes and propagation. Should probably use
// ca_map to propogate non divisibility dependencies across exact map. Still
Expand Down Expand Up @@ -500,6 +542,19 @@ void ContigIDs::build(const std::vector<IterDomain*>& ids) {
{root_domain_.begin(), root_domain_.end()},
{ids.begin(), ids.end()});
for (auto expr : exprs) {
if (auto resize = dynamic_cast<Resize*>(expr)) {
resize_deps_.insert(resize->out());
} else {
if (std::any_of(
expr->inputs().begin(), expr->inputs().end(), [&](Val* inp) {
return inp->isA<IterDomain>() &&
resize_deps_.count(inp->as<IterDomain>());
})) {
for (auto out : ir_utils::filterByType<IterDomain>(expr->outputs())) {
resize_deps_.insert(out);
}
}
}
handle(expr);
}
}
Expand Down Expand Up @@ -576,6 +631,12 @@ void ContigIDs::handle(Merge* merge) {
return;
}

// Don't allow contig indexing after resize as we need to traverse back
// at least to direct outputs of resize ops
if (resize_deps_.count(merge->out())) {
return;
}

// All broadcasting
if (last_root == nullptr) {
return;
Expand Down
7 changes: 7 additions & 0 deletions csrc/contiguity.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class OrderedIdInformation : public OptInDispatch {

void handle(Swizzle2D* swizzle) override;

void handle(Resize* resize) override;

// Track which root ids were used to generate each iter domain
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
id_to_root_ids_;
Expand Down Expand Up @@ -255,6 +257,8 @@ class ContigIDs : public OptInDispatch {
// cases, depending on specific swizzle type and axes.
void handle(Swizzle2D* swizzle) override {}

// Intentionally a no-op: ContigIDs::build() records resize outputs, and IDs
// derived from them, in resize_deps_ before dispatching each expression.
void handle(Resize* resize) override {}

IterDomain* getCAIndexConcreteId(IterDomain* id) const;

//! True if an ID is indexable.
Expand Down Expand Up @@ -307,6 +311,9 @@ class ContigIDs : public OptInDispatch {
std::unique_ptr<const OrderedIdInformation> consistent_transform_info_;

NonDivisibleSplitDependencies non_divisible_id_info_;

//! IDs that depend on resize output IDs
std::unordered_set<IterDomain*> resize_deps_;
};

} // namespace nvfuser
Loading