Support broadcasts of predicated tensors within thread blocks (#100)
This adds support for cases where a reduced tensor is an input to a parallelized tensor, so broadcasts after reductions, like the one below, should now work within a thread block:

t1 = sum(t0, {1});
t2 = broadcast(t1, {false, true});
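
For reference, a minimal sketch of how such a fusion might be built with the C++ test API (helper names like makeDummyTensor follow the existing tests; the scheduling shown is illustrative and is not the exact test added in this commit):

  Fusion fusion;
  FusionGuard fg(&fusion);

  // 2-D input tensor
  TensorView* tv0 = makeDummyTensor(2);
  fusion.addInput(tv0);

  // Reduce over the inner dimension, then broadcast it back
  TensorView* tv1 = sum(tv0, {1});
  TensorView* tv2 = broadcast(tv1, {false, true});

  // Consume the broadcast together with the original input
  TensorView* tv3 = add(tv0, tv2);
  fusion.addOutput(tv3);

  // Bind the reduced/broadcast axes to TIDx so the broadcast crosses a
  // thread-parallel dimension and lowers to blockBroadcast
  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);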

The major changes include:

- Add a blockBroadcast device function, which is used to broadcast across dimensions parallelized with TIDx/y/z (see the call sketch after this list).
- Update the softmax test. It now matches the ATen output (within a relaxed threshold).
- Add a simplified softmax test, which does not normalize the input by its max.
- Refactor thread predicate computation. Thread predicate information is needed for both lowering and printing, so I extracted it from the lowering pass and made it a more independent class.
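
As a rough picture of what the lowered kernel contains, a broadcast whose output crosses only the TIDx dimension is printed as a call of this shape (tensor names and indices are illustrative only):

  // broadcast T1 (valid only in threadIdx.x == 0) to every thread's T2
  broadcast::blockBroadcast<true, false, false>(T2[i2], T1[i1]);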

Limitations and concerns:

- Broadcasting to BID-parallelized dimensions is not supported.
- Thread predicates are computed twice, which might be a performance concern, but the cost should still be trivially small compared to, e.g., the computeAt implementation.
naoyam authored Jul 8, 2020
1 parent 0568040 commit 1eca373
Showing 14 changed files with 810 additions and 277 deletions.
401 changes: 327 additions & 74 deletions test/cpp/jit/test_gpu.cpp

Large diffs are not rendered by default.

144 changes: 74 additions & 70 deletions test/cpp/jit/tests.h
@@ -97,76 +97,80 @@ namespace jit {
_(FusionAliasing)

#if defined(USE_CUDA)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
_(CompleteArgumentSpec) \
_(Fusion) \
_(GraphExecutor) \
_(ModuleConversion) \
_(Interp) \
_(GPU_IrGraphGenerator) \
_(GPU_FusionDispatch) \
_(GPU_FusionClear) \
_(GPU_FusionCopy) \
_(GPU_FusionMove) \
_(GPU_FusionSimpleArith) \
_(GPU_FusionExprEvalConstants) \
_(GPU_FusionExprEvalBindings) \
_(GPU_FusionExprEvalBasic) \
_(GPU_FusionExprEvalComplex) \
_(GPU_FusionExprEvalPostLower) \
_(GPU_FusionSimpleTypePromote) \
_(GPU_FusionMutator) \
_(GPU_FusionRegister) \
_(GPU_FusionTopoSort) \
_(GPU_FusionTensor) \
_(GPU_FusionTVSplit) \
_(GPU_FusionTVMerge) \
_(GPU_FusionTVReorder) \
_(GPU_FusionEquality) \
_(GPU_FusionReplaceAll) \
_(GPU_FusionParser) \
_(GPU_FusionDependency) \
_(GPU_FusionCodeGen) \
_(GPU_FusionCodeGen2) \
_(GPU_FusionSimplePWise) \
_(GPU_FusionExecKernel) \
_(GPU_FusionForLoop) \
_(GPU_FusionLoopUnroll) \
_(GPU_FusionUnaryOps) \
_(GPU_FusionBinaryOps) \
_(GPU_FusionTernaryOps) \
_(GPU_FusionCompoundOps) \
_(GPU_FusionCastOps) \
_(GPU_FusionAdvancedComputeAt) \
_(GPU_FusionScalarInputs) \
_(GPU_FusionRFactorReplay) \
_(GPU_FusionReduction) \
_(GPU_FusionReduction2) \
_(GPU_FusionReduction3) \
_(GPU_FusionReduction4) \
_(GPU_FusionReduction5) \
_(GPU_FusionReductionTFT) \
_(GPU_FusionSimpleBCast) \
_(GPU_FusionSimpleGemm) \
_(GPU_FusionSoftmax) \
_(GPU_FusionSoftmaxComputeAt) \
_(GPU_FusionGridReduction1) \
_(GPU_FusionGridReduction2) \
_(GPU_FusionGridReduction3dim1) \
_(GPU_FusionGridReduction3dim0) \
_(GPU_FusionGridReduction4) \
_(GPU_FusionGridReduction5) \
_(GPU_FusionGridReduction6) \
_(GPU_FusionNonRedAxisBind) \
_(GPU_FusionBCastInnerDim) \
_(GPU_FusionBCastReduce) \
_(GPU_FusionSplitBCast) \
_(GPU_FusionComputeAtExprOrder) \
_(GPU_FusionZeroDimComputeAt) \
_(GPU_FusionZeroDimBroadcast) \
_(GPU_FusionZeroDimReduction) \
_(GPU_FusionReductionMultiConsumer)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
_(CompleteArgumentSpec) \
_(Fusion) \
_(GraphExecutor) \
_(ModuleConversion) \
_(Interp) \
_(GPU_IrGraphGenerator) \
_(GPU_FusionDispatch) \
_(GPU_FusionClear) \
_(GPU_FusionCopy) \
_(GPU_FusionMove) \
_(GPU_FusionSimpleArith) \
_(GPU_FusionExprEvalConstants) \
_(GPU_FusionExprEvalBindings) \
_(GPU_FusionExprEvalBasic) \
_(GPU_FusionExprEvalComplex) \
_(GPU_FusionExprEvalPostLower) \
_(GPU_FusionSimpleTypePromote) \
_(GPU_FusionMutator) \
_(GPU_FusionRegister) \
_(GPU_FusionTopoSort) \
_(GPU_FusionTensor) \
_(GPU_FusionTVSplit) \
_(GPU_FusionTVMerge) \
_(GPU_FusionTVReorder) \
_(GPU_FusionEquality) \
_(GPU_FusionReplaceAll) \
_(GPU_FusionParser) \
_(GPU_FusionDependency) \
_(GPU_FusionCodeGen) \
_(GPU_FusionCodeGen2) \
_(GPU_FusionSimplePWise) \
_(GPU_FusionExecKernel) \
_(GPU_FusionForLoop) \
_(GPU_FusionLoopUnroll) \
_(GPU_FusionUnaryOps) \
_(GPU_FusionBinaryOps) \
_(GPU_FusionTernaryOps) \
_(GPU_FusionCompoundOps) \
_(GPU_FusionCastOps) \
_(GPU_FusionAdvancedComputeAt) \
_(GPU_FusionScalarInputs) \
_(GPU_FusionRFactorReplay) \
_(GPU_FusionReduction) \
_(GPU_FusionReduction2) \
_(GPU_FusionReduction3) \
_(GPU_FusionReduction4) \
_(GPU_FusionReduction5) \
_(GPU_FusionReductionTFT) \
_(GPU_FusionSimpleBCast) \
_(GPU_FusionSimpleGemm) \
_(GPU_FusionSoftmax1D) \
_(GPU_FusionSoftmax1DNormalized) \
_(GPU_FusionSoftmax3D) \
_(GPU_FusionSoftmax3DNormalized) \
_(GPU_FusionSoftmaxComputeAt) \
_(GPU_FusionGridReduction1) \
_(GPU_FusionGridReduction2) \
_(GPU_FusionGridReduction3dim1) \
_(GPU_FusionGridReduction3dim0) \
_(GPU_FusionGridReduction4) \
_(GPU_FusionGridReduction5) \
_(GPU_FusionGridReduction6) \
_(GPU_FusionNonRedAxisBind) \
_(GPU_FusionBCastInnerDim) \
_(GPU_FusionBCastReduce) \
_(GPU_FusionSplitBCast) \
_(GPU_FusionComputeAtExprOrder) \
_(GPU_FusionZeroDimComputeAt) \
_(GPU_FusionZeroDimBroadcast) \
_(GPU_FusionZeroDimReduction) \
_(GPU_FusionReductionMultiConsumer) \
_(GPU_FusionBCastAfterReduce)
#else
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
63 changes: 54 additions & 9 deletions torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -1,6 +1,8 @@
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>

#include <iostream>

@@ -478,15 +480,50 @@ void IRPrinter::handle(const ReductionOp* rop) {
}

void IRPrinter::handle(const BroadcastOp* bop) {
indent();
handle(bop->out());
os << "\n";
indent_size++;
indent();
os << " = ";
handle(bop->in());
indent_size--;
os << ";\n";
// Check if we've lowered yet.
bool lowered = bop->out()->getValType() == ValType::TensorIndex;
if (!lowered) {
os << bop->out() << " = broadcast( " << bop->in() << " )\n";
return;
}

const ir_utils::ParallelTypeBitmap domains =
ir_utils::getParallelBroadcastDomains(bop, getThreadPredicateMap());
const bool thread_x = domains.get(ParallelType::TIDx);
const bool thread_y = domains.get(ParallelType::TIDy);
const bool thread_z = domains.get(ParallelType::TIDz);
const bool block_x = domains.get(ParallelType::BIDx);
const bool block_y = domains.get(ParallelType::BIDy);
const bool block_z = domains.get(ParallelType::BIDz);

const bool grid_broadcast_needed = block_x || block_y || block_z;
const bool block_broadcast_needed = thread_x || thread_y || thread_z;

TORCH_INTERNAL_ASSERT(
!grid_broadcast_needed, "Parallel broadcast across blocks not supported");

if (block_broadcast_needed) {
indent();
os << "broadcast::blockBroadcast<";
os << (thread_x ? "true" : "false") << ", ";
os << (thread_y ? "true" : "false") << ", ";
os << (thread_z ? "true" : "false");
os << ">(";
handle(bop->out());
os << ", ";
handle(bop->in());
os << ");\n";
} else {
indent();
handle(bop->out());
os << "\n";
indent_size++;
indent();
os << " = ";
handle(bop->in());
indent_size--;
os << ";\n";
}
}

void IRPrinter::handle(const ForLoop* fl) {
@@ -640,6 +677,14 @@ void IRPrinter::printKernel(
os << "}\n";
}

const ThreadPredicateMap& IRPrinter::getThreadPredicateMap() {
if (thread_predicates_ == nullptr) {
Fusion* fusion = FusionGuard::getCurFusion();
thread_predicates_ = std::make_unique<ThreadPredicateMap>(fusion);
}
return *thread_predicates_;
}

std::ostream& operator<<(std::ostream& os, const Statement* stmt) {
IRPrinter p(os);
p.handle(stmt);
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_iostream.h
@@ -3,6 +3,7 @@
#include <torch/csrc/WindowsTorchApiMacro.h>

#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>

#include <iostream>

@@ -126,6 +127,11 @@ class TORCH_CUDA_API IRPrinter : public OptInConstDispatch {
void printKernel(
const std::vector<Expr*>& exprs,
const std::string& kernel_name);

private:
std::unique_ptr<ThreadPredicateMap> thread_predicates_;

const ThreadPredicateMap& getThreadPredicateMap();
};

TORCH_CUDA_API std::ostream& operator<<(
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/kernel.cpp
@@ -156,7 +156,8 @@ std::pair<std::string, std::string> codeGeneration(Fusion* fusion) {
<< code_random_number_gen << "\n"
<< code_helper_funcs << "\n"
<< code_template_block_reduction << "\n"
<< code_template_grid_reduction << "\n";
<< code_template_grid_reduction << "\n"
<< code_template_block_broadcast << "\n";
std::stringstream cdg;
GPULower gpulw(fusion);
gpulw.printKernel(str_stream, kKernelName);
47 changes: 47 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_resource_strings.h
@@ -570,6 +570,53 @@ __device__ void gridReduce(T& out, T inp_val, Func reduction_op,
} // namespace reduction
)";

static auto code_template_block_broadcast = R"(
namespace broadcast {
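// Compute a shared-memory slot index from the thread dimensions that are
// NOT broadcast; threads that differ only in broadcast dimensions map to
// the same slot and therefore read the same value. For example, with
// X_THREAD = true and Y_THREAD = Z_THREAD = false, the x dimension is
// collapsed and the slot index is threadIdx.z * blockDim.y + threadIdx.y.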
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD>
__host__ __device__ unsigned offset_of_source(const dim3& block_dim, const dim3& thread_idx) {
unsigned offset = 0;
if (!Z_THREAD)
offset = offset * block_dim.z + thread_idx.z;
if (!Y_THREAD)
offset = offset * block_dim.y + thread_idx.y;
if (!X_THREAD)
offset = offset * block_dim.x + thread_idx.x;
return offset;
}
/** Broadcasts within partitioned groups of threads.
X_THREAD: Broadcast from threadIdx.x == 0 if true
Y_THREAD: Broadcast from threadIdx.y == 0 if true
Z_THREAD: Broadcast from threadIdx.z == 0 if true
inp_val: Per-thread source value. Only valid when the thread is a source.
out: Per-thread output location
*/
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T>
__device__ void blockBroadcast(T& out, T inp_val) {
// Statically sized for the worst case of 1024 threads per block.
__shared__ T shared_mem[1024];
const bool has_valid_data =
(!X_THREAD || threadIdx.x == 0) &&
(!Y_THREAD || threadIdx.y == 0) &&
(!Z_THREAD || threadIdx.z == 0);
const auto shared_offset = offset_of_source<X_THREAD, Y_THREAD, Z_THREAD>(blockDim, threadIdx);
if (has_valid_data)
shared_mem[shared_offset] = inp_val;
__syncthreads();
out = shared_mem[shared_offset];
}
} // namespace broadcast
)";

} // namespace cuda
} // namespace fuser
} // namespace jit
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -20,7 +20,7 @@ std::vector<Expr*> GPULower::getLoweredExprs() {
// Validate and make some minor modifications in preparation to generate code.
PrepareForLowering(fusion_);

auto preds = ThreadPredicates::compute(fusion_);
ThreadPredicateMap preds(fusion_);

// Run our passes keeping the lowered expressions and forwarding them.
auto loop_nests = LoopNestGenerator::getLoopNest(
12 changes: 6 additions & 6 deletions torch/csrc/jit/codegen/cuda/lower_loops.h
@@ -4,6 +4,8 @@
#include <torch/csrc/jit/codegen/cuda/dispatch.h>

#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>

namespace torch {
namespace jit {
namespace fuser {
@@ -40,7 +42,7 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {

// Predicates from ThreadPredicates that we will extend to reduction buffer
// initialization
std::unordered_map<const TensorView*, Bool*>& thread_predicates_;
ThreadPredicateMap& thread_predicates_;

// Create, place, and return the allocation for tv
Expr* pushAlloc(TensorView*);
@@ -71,16 +73,14 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {
// Run the pass and accumulate output in lowered_exprs
void generate(const std::vector<Expr*>& exprs);

LoopNestGenerator(
Fusion* _fusion,
std::unordered_map<const TensorView*, Bool*>& _thread_predicates)
LoopNestGenerator(Fusion* _fusion, ThreadPredicateMap& _thread_predicates)
: fusion_(_fusion), thread_predicates_(_thread_predicates) {}

public:
static std::vector<Expr*> getLoopNest(
Fusion* fusion,
std::vector<Expr*> exprs,
std::unordered_map<const TensorView*, Bool*>& thread_predicates) {
ThreadPredicateMap& thread_predicates) {
FusionGuard fg(fusion);
LoopNestGenerator lng(fusion, thread_predicates);
lng.generate(exprs);
@@ -90,4 +90,4 @@ class TORCH_CUDA_API LoopNestGenerator : public OptOutDispatch {

} // namespace fuser
} // namespace jit
} // namespace torch
} // namespace torch