Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,7 @@ inline void inferGroupInputs(
}
} // namespace

FusionSegmentRuntime::SchedulerEntryPtr SegmentedFusion::makeSchedulerEntry(
FusionKernelRuntime::SchedulerEntryPtr SegmentedFusion::makeSchedulerEntry(
SegmentedGroup* sg,
ExpressionEvaluator& ee) {
ExpressionEvaluator local_ee(&fusion_);
Expand All @@ -1268,12 +1268,12 @@ FusionSegmentRuntime::SchedulerEntryPtr SegmentedFusion::makeSchedulerEntry(
return SchedulerEntry::makeEntry(sg->heuristic(), &fusion_, local_ee);
}

std::unique_ptr<SegmentHeuristics> SegmentedFusion::makeHeuristics(
std::unique_ptr<FusionHeuristics> SegmentedFusion::makeHeuristics(
const at::ArrayRef<IValue>& inputs) {
auto ret = std::make_unique<SegmentHeuristics>();
auto ret = std::make_unique<FusionHeuristics>();
auto evaluator = executor_utils::bindFusionInputs(inputs, &fusion_);
for (auto g : groups()) {
ret->emplace_back(makeSchedulerEntry(g, evaluator));
ret->emplaceBack(makeSchedulerEntry(g, evaluator));
}
return ret;
}
Expand Down
44 changes: 34 additions & 10 deletions torch/csrc/jit/codegen/cuda/fusion_segmenter.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class TORCH_CUDA_CU_API SegmentedGroup {
private:
friend class SegmentCandidateFinder;
friend class SegmentedFusion;
friend class FusionSegmentRuntime;
friend class FusionKernelRuntime;

//! unique identifier of group in the segmented fusion
int group_id_ = -1;
Expand Down Expand Up @@ -174,23 +174,47 @@ class TORCH_CUDA_CU_API SegmentedGroup {

std::ostream& operator<<(std::ostream& os, const SegmentedGroup* group);

//! Auxiliary class for managing a list of heuristics instances for the
//! Segmented Groups
class TORCH_CUDA_CU_API SegmentHeuristics {
using SchedulerEntryPtr = std::unique_ptr<SchedulerEntry>;
//! Auxiliary class for storing heuristics. The managed data is either
//! a single scheduler entry for complete fusion,
//! or a vector of schedulers, one for each segment, for segmented fusion.
class TORCH_CUDA_CU_API FusionHeuristics {
using SchedulerEntryOwningPtr = std::unique_ptr<SchedulerEntry>;

public:
explicit SegmentHeuristics() = default;
void emplace_back(SchedulerEntryPtr&& pt) {
//! Constructor for segmented fusion case. Created with empty list and
//! uses emplaceBack for inserting heuristics in order
explicit FusionHeuristics() = default;

//! Constructor for complete fusion case, generates the scheduler entry
//! for the fusion owning the given expression
explicit FusionHeuristics(
ScheduleHeuristic schedule_heuristic,
ExpressionEvaluator& expr_eval) {
heuristics_.emplace_back(SchedulerEntry::makeEntry(
schedule_heuristic, expr_eval.fusion(), expr_eval));
is_segmented_ = false;
}

//! Place a scheduler entry on the list. Applies to segmented fusion only.
void emplaceBack(SchedulerEntryOwningPtr&& pt) {
TORCH_INTERNAL_ASSERT(is_segmented_);
heuristics_.emplace_back(std::move(pt));
}

const std::vector<SchedulerEntryPtr>& heuristics() const {
//! Returns list of schedulers for a segmneted fusion.
const std::vector<SchedulerEntryOwningPtr>& heuristicsList() const {
return heuristics_;
}

//! Returns the single scheduler for a complete fusion.
SchedulerEntry* singleHeuristics() {
TORCH_INTERNAL_ASSERT(!is_segmented_);
return heuristics_.begin()->get();
}

private:
std::vector<SchedulerEntryPtr> heuristics_;
std::vector<SchedulerEntryOwningPtr> heuristics_;
bool is_segmented_ = true;
};

//! Exported Interface for representing segmented fusion graph
Expand Down Expand Up @@ -237,7 +261,7 @@ class TORCH_CUDA_CU_API SegmentedFusion {
std::unique_ptr<Fusion> makeFusion(SegmentedGroup* sg);

//! Make heuristics for all groups in this segmented fusion
std::unique_ptr<SegmentHeuristics> makeHeuristics(
std::unique_ptr<FusionHeuristics> makeHeuristics(
const at::ArrayRef<IValue>& inputs);

//! Inline Debug print for segmented fusion
Expand Down
Loading