Enable transpose scheduler (#1927)
zasdfgbnm authored Sep 11, 2022
1 parent b7a206e commit bd93578
Showing 7 changed files with 405 additions and 144 deletions.
44 changes: 43 additions & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h
@@ -26,14 +26,18 @@ namespace HeuristicCompileTime {
//! Enum for all possible types of cached entries of compile-time info.
enum class CompileTimeEntryType {
DOMAIN_MAP,
TRANSPOSE_DOMAIN_MAP,
REFERENCE_TENSORS,
REFERENCE_TENSORS_FOR_GROUPS,
VECTORIZABLE_INPUTS_AND_OUTPUTS,
INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS,
UNROLLABLE_INPUTS_AND_OUTPUTS,
REDUCTION_TVS,
PERSISTENT_BUFFER_INFO,
SCOPE_PERSISTENT_FACTOR_INFO,
BROADCAST_BYTE_MULTIPLES
BROADCAST_BYTE_MULTIPLES,
INNER_MOST_DIMS_INFO,
CAN_SCHEDULE_TRANSPOSE,
};

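Each cached quantity follows the same three-part pattern: an enum value above, an entry type definition class below pairing a DataType with that value, and an explicit instantiation of HeuristicSummaryEntry at the bottom of registry.cpp. A minimal illustrative sketch, with a hypothetical class name that is not part of this diff:

// Illustrative sketch only -- the shape every entry definition class takes.
class MyNewEntry {
 public:
  // Payload type stored in the cache for this entry.
  using DataType = bool;
  // Enum tag that keys this entry inside HeuristicSummary.
  static const CompileTimeEntryType EntryType =
      CompileTimeEntryType::CAN_SCHEDULE_TRANSPOSE;
};
// registry.cpp would then need:
//   template class HeuristicSummaryEntry<MyNewEntry>;
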
//! Entry type definition class for `DOMAIN_MAP`,
@@ -45,6 +49,15 @@ class DomainMap {
CompileTimeEntryType::DOMAIN_MAP;
};

//! Entry type definition class for `TRANSPOSE_DOMAIN_MAP`,
//! stores the domain map of a fusion used by the transpose scheduler.
class TransposeDomainMap {
public:
using DataType = pointwise_utils::DomainMap;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::TRANSPOSE_DOMAIN_MAP;
};

//! Entry type definition class for `REFERENCE_TENSORS`,
//! stores the reference TensorViews used to schedule a fusion.
class ReferenceTensors {
@@ -54,6 +67,16 @@ class ReferenceTensors {
CompileTimeEntryType::REFERENCE_TENSORS;
};

//! Entry type definition class for `REFERENCE_TENSORS_FOR_GROUPS`,
//! stores the reference TensorViews used to schedule a fusion, used by the
//! transpose scheduler.
class ReferenceTensorsForGroups {
public:
using DataType = std::vector<TensorView*>;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::REFERENCE_TENSORS_FOR_GROUPS;
};

//! Entry type definition class for `VECTORIZABLE_INPUTS_AND_OUTPUTS`,
//! stores the vectorizable TensorViews on a fusion's inputs and outputs.
class VectorizableInputsAndOutputs {
@@ -99,6 +122,16 @@ class PersistentBufferInfo {
CompileTimeEntryType::PERSISTENT_BUFFER_INFO;
};

//! Entry type definition class for `INNER_MOST_DIMS_INFO`,
//! used in the transpose scheduler to store the innermost IterDomains and
//! their positions in reference1 of group 1 and group 2.
class InnerMostDimInfo {
public:
using DataType = std::vector<int64_t>;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::INNER_MOST_DIMS_INFO;
};

//! Auxiliary data types for `SCOPE_PERSISTENT_FACTOR_INFO` entry type.
using ScopedPersistenceBufferMap = std::unordered_map<Val*, std::vector<bool>>;

@@ -126,6 +159,15 @@ class BroadcastMultiples {
CompileTimeEntryType::BROADCAST_BYTE_MULTIPLES;
};

//! Entry type definition class for `CAN_SCHEDULE_TRANSPOSE`,
//! stores whether the transpose scheduler can schedule this fusion.
class CanScheduleTranspose {
public:
using DataType = bool;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::CAN_SCHEDULE_TRANSPOSE;
};

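For context, a minimal sketch of how such an entry is read through the HeuristicSummaryEntry cache, mirroring the lookup registry.cpp performs below; the wrapper function itself is a hypothetical illustration, not part of this diff:

// Sketch: query the cached CAN_SCHEDULE_TRANSPOSE answer. The lambda runs
// only on a cache miss; later queries reuse the stored bool.
bool cachedCanScheduleTranspose(Fusion* fusion, HeuristicSummary* data_cache) {
  auto entry =
      HeuristicSummaryEntry<HeuristicCompileTime::CanScheduleTranspose>(
          data_cache, [fusion]() {
            return std::make_unique<bool>(
                TransposeScheduler::canScheduleCompileTime(fusion));
          });
  return entry.get();
}
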
//! Base abstract class for unified storage in `HeuristicSummary`,
//! each entry in `HeuristicSummary` will be a subclass.
class CompileTimeInfoBase : public PolymorphicBase {
196 changes: 116 additions & 80 deletions torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
@@ -986,6 +986,84 @@ class ReductionScheduler : public SchedulerEntry {
}
};

class TransposeScheduler : public SchedulerEntry {
public:
explicit TransposeScheduler(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr)
: SchedulerEntry(ScheduleHeuristic::Transpose) {
computeHeuristics(fusion, runtime_info, data_cache);
}

static bool canScheduleCompileTime(Fusion* fusion) {
// Temporarily disallow view in transpose scheduler
// TODO: Add more testing before enabling
auto view_tvs = scheduler_utils::getViewTVs(fusion);
if (view_tvs.size() > 0) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose, "No support for view op");
return false;
}

if (!hasAtLeastTwoValidGroups(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose,
"cannot find two mismatching inner most dimensions");
return false;
}

// TODO: add support for trivial reduction
auto reduction_ops =
ir_utils::getReductionOps(fusion, false /* ignore_trivial */);

if (!reduction_ops.empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose, "no support for reduction ops");
return false;
}

if (hasNonUniqueBcast(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose,
"Broadcasting dimension might be broadcasting to multiple sizes.");
return false;
}

return true;
}

static bool canScheduleRunTime(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr) {
FUSER_PERF_SCOPE("TransposeScheduler::canScheduleRunTime");

auto reason =
getTransposeRuntimeRejectReason(fusion, data_cache, runtime_info);
if (!reason.empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose, reason);
return false;
}
return true;
}

void schedule(Fusion* fusion) override {
FUSER_PERF_SCOPE("Schedule Transpose Fusion");
scheduleTranspose(fusion, transposeParams());
}

private:
void computeHeuristics(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr) {
params_ = getTransposeHeuristics(fusion, runtime_info, data_cache);
TORCH_INTERNAL_ASSERT(params_ != nullptr);
}
};
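
The two static checks above split acceptance in two: canScheduleCompileTime inspects fusion structure only (view ops, reductions, innermost-dimension groups, broadcasts), so its result can be cached, while canScheduleRunTime consults getTransposeRuntimeRejectReason with the actual input sizes. A hedged sketch of how a caller might combine them; the helper name is hypothetical and the real dispatch lives elsewhere in registry.cpp:

// Hypothetical helper, for illustration only: structural test first
// (cheap, cacheable), then the shape-dependent runtime test.
bool canUseTransposeScheduler(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicSummary* data_cache = nullptr) {
  return TransposeScheduler::canScheduleCompileTime(fusion) &&
      TransposeScheduler::canScheduleRunTime(fusion, runtime_info, data_cache);
}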

class PointWiseScheduler : public SchedulerEntry {
public:
explicit PointWiseScheduler(
@@ -1037,6 +1115,18 @@ class PointWiseScheduler : public SchedulerEntry {
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr) {
auto can_schedule_transpose_entry =
HeuristicSummaryEntry<HeuristicCompileTime::CanScheduleTranspose>(
data_cache, [fusion]() {
return std::make_unique<bool>(
TransposeScheduler::canScheduleCompileTime(fusion));
});
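    // If the transpose scheduler accepts this fusion at compile time,
    // point-wise schedules it only when the transpose runtime check
    // rejects it (non-empty reason); otherwise defer to transpose.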
if (can_schedule_transpose_entry.get()) {
auto reason =
getTransposeRuntimeRejectReason(fusion, data_cache, runtime_info);
return !reason.empty();
}

return true;
}

@@ -1283,81 +1373,6 @@ class PersistentKernelScheduler : public SchedulerEntry {
}
};

class TransposeScheduler : public SchedulerEntry {
public:
explicit TransposeScheduler(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr)
: SchedulerEntry(ScheduleHeuristic::Transpose) {
computeHeuristics(fusion, runtime_info, data_cache);
}

static bool canScheduleCompileTime(Fusion* fusion) {
if (!isOptionEnabled(EnableOption::TransposeScheduler)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose, "not enabled");
return false;
}

// Temporarily disallow view in transpose scheduler
// TODO Add more testing before enabling
auto view_tvs = scheduler_utils::getViewTVs(fusion);
if (view_tvs.size() > 0) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose, "No support for view op");
return false;
}

if (!hasAtLeastTwoValidGroups(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose,
"cannot find two mismatching inner most dimensions");
return false;
}

// TODO: add support for trivial reduction
auto reduction_ops =
ir_utils::getReductionOps(fusion, false /* ignore_trivial */);

if (!reduction_ops.empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose, "no support for reduction ops");
return false;
}

if (hasNonUniqueBcast(fusion)) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Transpose,
"Broadcasting dimension might be broadcasting to multiple sizes.");
return false;
}

return true;
}

static bool canScheduleRunTime(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr) {
return true;
}

void schedule(Fusion* fusion) override {
FUSER_PERF_SCOPE("Schedule Transpose Fusion");
scheduleTranspose(fusion, transposeParams());
}

private:
void computeHeuristics(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr) {
params_ = getTransposeHeuristics(fusion, runtime_info, data_cache);
TORCH_INTERNAL_ASSERT(params_ != nullptr);
}
};

// Schedule Table
const std::vector<ScheduleHeuristic>& all_heuristics() {
static const std::vector<ScheduleHeuristic> hlist = {
@@ -1550,6 +1565,26 @@ void HeuristicSummary::validate() const {
entry_type_map_.count(EntryType::VECTORIZABLE_INPUTS_AND_OUTPUTS));
TORCH_INTERNAL_ASSERT(
entry_type_map_.count(EntryType::BROADCAST_BYTE_MULTIPLES));
TORCH_INTERNAL_ASSERT(
entry_type_map_.count(EntryType::CAN_SCHEDULE_TRANSPOSE));
auto can_schedule_transpose =
entry_type_map_.at(EntryType::CAN_SCHEDULE_TRANSPOSE)
->as<
CompileTimeInfo<HeuristicCompileTime::CanScheduleTranspose>>()
->get();
if (!*can_schedule_transpose) {
break;
}
}
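      // Intentional fall-through: when the transpose scheduler could take
      // this fusion, the point-wise cache must also hold the transpose
      // entries validated in the next case.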
case ScheduleHeuristic::Transpose: {
TORCH_INTERNAL_ASSERT(
entry_type_map_.count(EntryType::TRANSPOSE_DOMAIN_MAP));
TORCH_INTERNAL_ASSERT(entry_type_map_.count(
EntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS));
TORCH_INTERNAL_ASSERT(
entry_type_map_.count(EntryType::REFERENCE_TENSORS_FOR_GROUPS));
TORCH_INTERNAL_ASSERT(
entry_type_map_.count(EntryType::INNER_MOST_DIMS_INFO));
break;
}
case ScheduleHeuristic::Reduction: {
@@ -1579,11 +1614,6 @@ void HeuristicSummary::validate() const {
entry_type_map_.count(EntryType::SCOPE_PERSISTENT_FACTOR_INFO));
break;
}
case ScheduleHeuristic::Transpose: {
TORCH_INTERNAL_ASSERT(entry_type_map_.count(
EntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS));
break;
}
default:
TORCH_INTERNAL_ASSERT(false, "unknown heuristic");
}
@@ -1620,7 +1650,10 @@ HeuristicSummaryEntry<EntryClass>::HeuristicSummaryEntry(

// Template instantiation for pre-defined cache entries
template class HeuristicSummaryEntry<HeuristicCompileTime::DomainMap>;
template class HeuristicSummaryEntry<HeuristicCompileTime::TransposeDomainMap>;
template class HeuristicSummaryEntry<HeuristicCompileTime::ReferenceTensors>;
template class HeuristicSummaryEntry<
HeuristicCompileTime::ReferenceTensorsForGroups>;
template class HeuristicSummaryEntry<
HeuristicCompileTime::VectorizableInputsAndOutputs>;
template class HeuristicSummaryEntry<
@@ -1633,6 +1666,9 @@ template class HeuristicSummaryEntry<
template class HeuristicSummaryEntry<
HeuristicCompileTime::ScopePersistentFactorInfo>;
template class HeuristicSummaryEntry<HeuristicCompileTime::BroadcastMultiples>;
template class HeuristicSummaryEntry<HeuristicCompileTime::InnerMostDimInfo>;
template class HeuristicSummaryEntry<
HeuristicCompileTime::CanScheduleTranspose>;

} // namespace cuda
} // namespace fuser