Grouping grid allreduces across iterations #1755

Merged · 17 commits · Jul 5, 2022
259 changes: 216 additions & 43 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -120,6 +120,38 @@ std::string genCall(
return ss.str();
}

//! A utility class to check if an expression of a particular type exists
class ExprFinder : kir::ConstIrVisitor {
public:
//! True if expr or any of its nested expressions is included in
//! expr_types
static bool exists(
const Expr* expr,
const std::unordered_set<ExprType>& expr_types) {
ExprFinder finder(expr_types);
finder.handle(std::vector<const Expr*>{expr});
return finder.is_found_;
}

private:
ExprFinder(const std::unordered_set<ExprType>& expr_types)
: expr_types_(expr_types) {}

using kir::ConstIrVisitor::handle;

void handle(const Expr* expr) final {
if (expr_types_.find(expr->etype()) != expr_types_.end()) {
is_found_ = true;
return;
}
kir::ConstIrVisitor::handle(expr);
}

private:
const std::unordered_set<ExprType>& expr_types_;
bool is_found_ = false;
};
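
//! Usage sketch (illustrative; mirrors how isGroupedLoop() further below
//! calls it): ExprFinder::exists(loop, {ExprType::GroupedGridReduction})
//! returns true iff a grouped grid reduction appears anywhere under loop.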

class CudaKernelGenerator : private OptOutConstDispatch {
static constexpr const char* kTab = " ";

@@ -397,6 +429,14 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}

void handle(const Int* i) final {
// Check the replacement map first. If there's an entry for i, use
// the corresponding replacement.
auto replace_it = index_replacement_map_.find(i);
if (replace_it != index_replacement_map_.end()) {
code_ << replace_it->second;
return;
}

const auto def = i->definition();
const bool has_alloc = alloc_map_.find(i) != alloc_map_.end();
if (def != nullptr && !has_alloc) {
@@ -1535,7 +1575,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}

TORCH_INTERNAL_ASSERT(
grouped_grop->numReductions() == 2,
grouped_grop->numExprs() == 2,
"Only grouping of 2 reductions is supported. ",
grouped_grop->toString());

@@ -1554,7 +1594,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
ArgumentBuilder func_args(block_nest_level_ + 1, kTab);

// Append arguments for each reduction
for (const auto i : c10::irange(grouped_grop->numReductions())) {
for (const auto i : c10::irange(grouped_grop->numExprs())) {
TORCH_INTERNAL_ASSERT(
grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>());
const auto work_buffer =
@@ -1596,17 +1636,106 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << kTab << func_args << ");\n";
}

// Enumerates all combinations of index values of grouped
// loops. Each combination is a vector of loop index values. The
// length of the vector is the number of grouped loops.
//
// Example 1: only one domain of extent 2 is grouped: {{0}, {1}}.
// Example 2: two domains of extents 2 and 3 are grouped: {{0, 0},
// {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}}
std::vector<std::vector<int64_t>> getGroupedLoopIndexConcreteIntSets() {
std::vector<std::vector<int64_t>> index_combinations;

// Initialize with an empty vector
index_combinations.push_back(std::vector<int64_t>());

// Incrementally build a combinatorial set
for (const auto loop : grouped_loops_) {
const auto iter_count = loop->stop()->evaluateInt();
std::vector<std::vector<int64_t>> new_combinations;
// Append integers from 0 to iter_count to all the vectors built
// so far
for (const auto& index_vec : index_combinations) {
for (int64_t i = 0; i < iter_count; ++i) {
auto index_vec_appended = index_vec;
index_vec_appended.push_back(i);
new_combinations.push_back(index_vec_appended);
}
}
index_combinations = std::move(new_combinations);
}

return index_combinations;
}
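
// Illustrative trace of the incremental build above for two grouped
// loops of extents 2 and 3 (matching Example 2):
//   start:          { {} }
//   after extent 2: { {0}, {1} }
//   after extent 3: { {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2} }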

//! Returns maps from the index Vals of grouped loops to their concrete
//! integer values, one map per combination.
std::vector<std::unordered_map<const Int*, int64_t>>
getLoopIndexReplacementMaps() {
std::vector<std::unordered_map<const Int*, int64_t>> maps;

if (grouped_loops_.empty()) {
std::unordered_map<const Int*, int64_t> empty_map;
return {empty_map};
}

// Vector of indices of grouped loops
std::vector<Int*> loop_indices;
std::transform(
grouped_loops_.begin(),
grouped_loops_.end(),
std::back_inserter(loop_indices),
[](const kir::ForLoop* loop) { return loop->index()->as<Int>(); });

// All combinations of loop index integer values
const auto index_val_sets = getGroupedLoopIndexConcreteIntSets();

// Create maps from loop index Vals to integers
for (const auto& index_values : index_val_sets) {
TORCH_INTERNAL_ASSERT(loop_indices.size() == index_values.size());
std::unordered_map<const Int*, int64_t> index_val_map;
for (const auto i : c10::irange(loop_indices.size())) {
auto loop_index = loop_indices.at(i);
auto index_val = index_values.at(i);
index_val_map.emplace(loop_index, index_val);
}
maps.emplace_back(std::move(index_val_map));
}

return maps;
}
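
// For illustration, with hypothetical grouped loop index Vals i0 and i1 of
// extents 2 and 3, this returns six maps:
//   {i0:0, i1:0}, {i0:0, i1:1}, {i0:0, i1:2},
//   {i0:1, i1:0}, {i0:1, i1:1}, {i0:1, i1:2}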

void generateGroupedGridAllreduce(
const kir::GroupedGridReduction* grouped_grop) {
TORCH_INTERNAL_ASSERT(grouped_grop->isAllreduce());

constexpr int max_num_reductions = 8;
// There are two dimensions of grouping: horizontal grouping and
// iteration grouping. The total number of individual reductions
// is the number of horizontal reductions * the extent of grouped
// iterations. All of them are packed into a single grid reduction
// call. The number of reductions is limited, and currently it is
// simply an error if exceeded. This could be avoided by
// decomposing grouped_grop into smaller groups within the
// limit. TODO: Support a larger number of reductions.
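//
// For example (made-up sizes): grouping 2 horizontal reductions over a
// grouped iteration domain of extent 3 packs 2 * 3 = 6 individual
// reductions into the single call, which must not exceed
// kMaxNumGroupedReductions.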

// First, enumerate all combinations of loop index values of
// grouped IterDomains. If only a single domain is grouped, this
// is simply a 1D vector of integers from 0 to extent-1. If
// two domains are grouped, combinations of two integer vectors
// are returned. These loop index value vectors are returned as a
// map from loop index Vals to concrete int values.
const auto index_replacement_maps = getLoopIndexReplacementMaps();
const auto num_grouped_iterations = index_replacement_maps.size();

// This is also checked at lowering validation time, so it
// isn't strictly necessary.
TORCH_INTERNAL_ASSERT(
grouped_grop->numReductions() <= max_num_reductions,
num_grouped_iterations * grouped_grop->numExprs() <=
kMaxNumGroupedReductions,
"Too many grouped reductions: ",
grouped_grop->toString(),
". Up to ",
max_num_reductions,
kMaxNumGroupedReductions,
" reductions are allowed.");

ArgumentBuilder types;
@@ -1620,44 +1749,65 @@ class CudaKernelGenerator : private OptOutConstDispatch {
ArgumentBuilder read_preds;
ArgumentBuilder write_preds;

for (const auto i : c10::irange(grouped_grop->numReductions())) {
const auto data_type = grouped_grop->outputs().at(i)->dtype();
TORCH_INTERNAL_ASSERT(
grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>());

types.arg(data_type);
for (const auto expr_index : c10::irange(grouped_grop->numExprs())) {
const auto data_type = grouped_grop->outputs().at(expr_index)->dtype();
TORCH_INTERNAL_ASSERT(grouped_grop->reduction_buffers()
.at(expr_index)
->buffer()
->isA<TensorView>());

// out
outputs.arg(gen(grouped_grop->outputs().at(i)));
for (const auto& group_index :
c10::irange(index_replacement_maps.size())) {
// Set the index replacement map with the concrete values of
// indices of grouped loops.
index_replacement_map_ = index_replacement_maps.at(group_index);

// inp
inputs.arg(gen(grouped_grop->inputs().at(i)));
types.arg(data_type);

// global_work_buffer
const auto work_buffer =
grouped_grop->reduction_buffers().at(i)->buffer()->as<TensorView>();
work_bufs.arg("&").append(varName(work_buffer)).append("[0]");

init_vals.arg(genInline(grouped_grop->initVal(i)));

reduction_ops.arg(genReductionOp(
grouped_grop->getReductionOpType(i),
grouped_grop->output(i)->dtype()));
// out
outputs.arg(gen(grouped_grop->outputs().at(expr_index)));

// inp
inputs.arg(gen(grouped_grop->inputs().at(expr_index)));

// global_work_buffer
const auto work_buffer = grouped_grop->reduction_buffers()
.at(expr_index)
->buffer()
->as<TensorView>();
// A separate work buffer is used for each reduction.
auto work_buffer_offset = group_index == 0
? "0"
: (genInline(grouped_grop->buffer_stride()) + " * " +
std::to_string(group_index));
work_bufs.arg("&")
.append(varName(work_buffer))
.append("[")
.append(work_buffer_offset)
.append("]");
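// For illustration with made-up numbers: if buffer_stride evaluated to 256
// and group_index were 2, the offset above would be 256 * 2 = 512, so each
// grouped iteration reduces into its own slice of this reduction's work
// buffer.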
init_vals.arg(genInline(grouped_grop->initVal(expr_index)));

reduction_ops.arg(genReductionOp(
grouped_grop->getReductionOpType(expr_index),
grouped_grop->output(expr_index)->dtype()));

// read and write predicates
bool_types.arg("bool");
// Same argument for all inputs. Different predicates would be
// used when grouping is done across iterations
TORCH_INTERNAL_ASSERT(
grouped_grop->predicate() != nullptr &&
grouped_grop->predicate()->hasValue());
const auto read_pred = genInline(grouped_grop->predicate());
read_preds.arg(read_pred);
if (grouped_grop->writePredicate() != nullptr) {
TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
write_preds.arg(genInline(grouped_grop->writePredicate()));
} else {
write_preds.arg(read_pred);
}

// read and write predicates
bool_types.arg("bool");
// Same argument for all inputs. Different predicates would be
// used when grouping is done across iterations
TORCH_INTERNAL_ASSERT(
grouped_grop->predicate() != nullptr &&
grouped_grop->predicate()->hasValue());
const auto read_pred = genInline(grouped_grop->predicate());
read_preds.arg(read_pred);
if (grouped_grop->writePredicate() != nullptr) {
TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
write_preds.arg(genInline(grouped_grop->writePredicate()));
} else {
write_preds.arg(read_pred);
index_replacement_map_.clear();
}
}

@@ -1975,7 +2125,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {

void handleTrivialLoop(const kir::ForLoop* loop) {
if (loop->vectorize()) {
vectorize_scope_ = loop->vectorize();
vectorize_scope_ = true;
}
handleScope(loop->body());
if (loop->vectorize()) {
@@ -1984,7 +2134,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}

void handle(const GroupedReductionOp* grouped_rop) final {
for (const auto i : c10::irange(grouped_rop->numReductions())) {
for (const auto i : c10::irange(grouped_rop->numExprs())) {
TORCH_INTERNAL_ASSERT(grouped_rop->output(i)->isA<kir::TensorIndex>());

const auto output = grouped_rop->output(i)->as<kir::TensorIndex>();
@@ -1997,7 +2147,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {

TORCH_INTERNAL_ASSERT(
!has_grid_reduce,
"GroupedReductionOp does not support block parallelization. GroupedGridReductionOp must be used. ",
"GroupedReductionOp does not support block parallelization. GroupedGridReduction must be used. ",
grouped_rop->toString());

if (!has_block_reduce) {
@@ -2023,12 +2173,32 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

//! True if the loop is grouped. The IterDomain of the loop must have
//! ParallelType::Group, but that alone isn't sufficient, as the loop may be
//! for an initialization expression, for which the loop should not
//! be grouped. Make sure a GroupedGridReduction is found.
bool isGroupedLoop(const kir::ForLoop* loop) {
if (loop->iter_domain()->getParallelType() != ParallelType::Group) {
return false;
}
return ExprFinder::exists(loop, {ExprType::GroupedGridReduction});
}

void handle(const kir::ForLoop* loop) final {
if (loop->isTrivial()) {
handleTrivialLoop(loop);
return;
}

// If a loop is grouped, no for-loop is generated, but it isn't
// considered trivial as its trip count is not one.
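// Instead, the loop body is visited once, and generateGroupedGridAllreduce()
// expands the grouped reduction across all combinations of the grouped loop
// indices via index_replacement_map_.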
if (isGroupedLoop(loop)) {
grouped_loops_.push_back(loop);
handleScope(loop->body());
grouped_loops_.pop_back();
return;
}

const auto gen_index = gen(loop->index());
const auto gen_start = genInline(loop->start());
const auto gen_stop = genInline(loop->stop());
@@ -2213,10 +2383,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {

// Mark when we are inside of a vectorized for-loop
bool vectorize_scope_ = false;

//! Keep track of Allocate node for Val. Used to determine if Val
//! should be inlined.
std::unordered_map<const Val*, const kir::Allocate*> alloc_map_;
//! Keep track of grouped loops
std::deque<const kir::ForLoop*> grouped_loops_;
//! Used to replace symbolic indices with concrete values
std::unordered_map<const Int*, int64_t> index_replacement_map_;
};

} // namespace
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/inline_propagator.cpp
@@ -66,7 +66,10 @@ bool MaxPosCalculator::isAllowedID(
}

if (!allow_vectorize) {
// Avoid inlining if marked as Vectorize or Group. In the case of
// BestEffort and MostInlined modes, avoid Unroll as well.
bool is_vectorize = isParallelTypeVectorize(id->getParallelType()) ||
id->getParallelType() == ParallelType::Group ||
((mode_ == ComputeAtMode::BestEffort ||
mode_ == ComputeAtMode::MostInlined) &&
id->getParallelType() == ParallelType::Unroll);
6 changes: 5 additions & 1 deletion torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -204,7 +204,9 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr {

GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner);

size_t numReductions() const {
//! Number of expressions grouped horizontally. It does not reflect
//! iteration grouping.
size_t numExprs() const {
return reduction_op_types_.size();
}

@@ -231,7 +233,9 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr {
bool sameAs(const Statement* other) const override;

private:
//! Reduction ops of grouped reductions
const std::vector<BinaryOpType> reduction_op_types_;
//! Initial values of grouped reductions
const std::vector<Val*> init_vals_;
//! True if using the fused reduction kernel
bool is_allreduce_ = false;