
Commit

cleanup (#1997)
naoyam committed Sep 27, 2022
1 parent 4cbe0db commit 482386c
Showing 3 changed files with 21 additions and 67 deletions.
47 changes: 16 additions & 31 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -441,9 +441,8 @@ void IndexCompute::handle(Merge* merge) {
 
   // When the reference has halo extent for inner_id, that extent needs to
   // be used to un-merge
-  if (reference_halo_extent_map_.find(inner_id) !=
-      reference_halo_extent_map_.end()) {
-    inner_extent = reference_halo_extent_map_[inner_id];
+  if (halo_extent_map_.find(inner_id) != halo_extent_map_.end()) {
+    inner_extent = halo_extent_map_[inner_id];
   }
 
   const auto outer_extent = getExtent(outer_id);
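The rewritten condition is the standard unordered_map find-or-fallback idiom applied to the renamed member. A minimal self-contained sketch of the pattern, with simplified stand-in types (Id, Extent, and getExtentWithHalo are illustrative names, not nvFuser code):

#include <unordered_map>

// Stand-in types for IterDomain* / Val* (illustrative only).
using Id = int;
using Extent = long;

// Prefer the halo-extended extent recorded for `id`; otherwise fall
// back to the ordinary extent. Mirrors the lookup in handle(Merge*).
Extent getExtentWithHalo(
    const std::unordered_map<Id, Extent>& halo_extent_map,
    Id id,
    Extent default_extent) {
  auto it = halo_extent_map.find(id);
  return it != halo_extent_map.end() ? it->second : default_extent;
}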
@@ -588,7 +587,7 @@ IndexCompute::IndexCompute(
     std::unordered_set<IterDomain*> zero_domains,
     std::unordered_set<IterDomain*> zero_merged_in,
     std::unordered_set<IterDomain*> preferred_paths,
-    std::unordered_map<IterDomain*, Val*> reference_halo_extent_map)
+    std::unordered_map<IterDomain*, Val*> halo_extent_map)
     : IndexCompute(
           _td,
           std::move(initial_index_map),
@@ -601,7 +600,7 @@ IndexCompute::IndexCompute(
           std::vector<bool>(_td->getMaybeRFactorDomain().size(), false),
           {}),
       std::move(preferred_paths),
-      std::move(reference_halo_extent_map)) {}
+      std::move(halo_extent_map)) {}
 
 IndexCompute::IndexCompute(
     const TensorDomain* _td,
@@ -611,14 +610,14 @@ IndexCompute::IndexCompute(
     std::unordered_set<IterDomain*> zero_merged_in,
     const ContigIDs& contig_finder,
     std::unordered_set<IterDomain*> preferred_paths,
-    std::unordered_map<IterDomain*, Val*> reference_halo_extent_map)
+    std::unordered_map<IterDomain*, Val*> halo_extent_map)
     : td_(_td),
       index_map_(std::move(initial_index_map)),
       extent_map_(std::move(extent_map)),
       zero_domains_(std::move(zero_domains)),
       zero_merged_in_(std::move(zero_merged_in)),
       preferred_paths_(std::move(preferred_paths)),
-      reference_halo_extent_map_(std::move(reference_halo_extent_map)) {
+      halo_extent_map_(std::move(halo_extent_map)) {
   FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
 
   // Make sure we recompute any indices we can that map to a contiguous access
@@ -641,11 +640,11 @@ IndexCompute::IndexCompute(
     std::unordered_map<IterDomain*, Val*> initial_index_map,
     std::unordered_set<IterDomain*> zero_domains,
     std::unordered_set<IterDomain*> preferred_paths,
-    std::unordered_map<IterDomain*, Val*> reference_halo_extent_map)
+    std::unordered_map<IterDomain*, Val*> halo_extent_map)
     : index_map_(std::move(initial_index_map)),
       zero_domains_(std::move(zero_domains)),
       preferred_paths_(std::move(preferred_paths)),
-      reference_halo_extent_map_(std::move(reference_halo_extent_map)) {
+      halo_extent_map_(std::move(halo_extent_map)) {
   FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
   concrete_id_pass_ = true;
   swizzle_mode_ = SwizzleMode::Loop;
@@ -790,15 +789,14 @@ bool IndexCompute::isZero(IterDomain* id) const {
 IndexCompute IndexCompute::updateIndexCompute(
     const TensorDomain* new_td,
     const std::unordered_map<IterDomain*, IterDomain*>& id_map,
-    const ContigIDs& contig_finder,
-    const std::unordered_map<IterDomain*, Val*>& reference_halo_extent_map)
-    const {
+    const ContigIDs& contig_finder) const {
   FUSER_PERF_SCOPE("GpuLower::Lower::updateIndexCompute");
 
   std::unordered_map<IterDomain*, Val*> updated_index_map;
   std::unordered_map<IterDomain*, Val*> updated_extent_map;
   std::unordered_set<IterDomain*> updated_zero_domains;
   std::unordered_set<IterDomain*> updated_zero_merged_in;
+  std::unordered_map<IterDomain*, Val*> updated_halo_extent_map;
 
   for (auto id_entry : id_map) {
     IterDomain* prev_id = id_entry.first;
@@ -817,6 +815,11 @@ IndexCompute IndexCompute::updateIndexCompute(
     if (zero_merged_in_.find(prev_id) != zero_merged_in_.end()) {
       updated_zero_merged_in.emplace(new_id);
     }
+
+    auto halo_extent_it = halo_extent_map_.find(prev_id);
+    if (halo_extent_it != halo_extent_map_.end()) {
+      updated_halo_extent_map[new_id] = halo_extent_it->second;
+    }
   }
 
   IndexCompute updated_index_compute(
@@ -827,25 +830,7 @@ IndexCompute IndexCompute::updateIndexCompute(
       updated_zero_merged_in,
       contig_finder,
       {},
-      reference_halo_extent_map);
-
-  if (concrete_id_pass_) {
-    // This should be the same behavior as with a reference tensor
-    // created, since originally halo was pulled through exact
-    // ca mapping and in the concrete_id_pass case, the id_map
-    // also represents exact ca mapping.
-    // TODO: might need to re-visit pathological cases when we may
-    // need to traverse and propagate halo info again in here.
-    for (auto id_entry : id_map) {
-      IterDomain* prev_id = id_entry.first;
-      IterDomain* new_id = id_entry.second;
-      auto halo_extent_it = reference_halo_extent_map_.find(prev_id);
-      if (halo_extent_it != reference_halo_extent_map_.end()) {
-        updated_index_compute.reference_halo_extent_map_[new_id] =
-            halo_extent_it->second;
-      }
-    }
-  }
+      updated_halo_extent_map);
 
   updated_index_compute.run();
 
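The deleted if (concrete_id_pass_) block is subsumed by the new lookup inside the id_map loop above: halo extents are now remapped unconditionally before the updated IndexCompute is constructed, rather than in a second pass afterwards. The operation is just carrying map entries across a key translation; a self-contained sketch with simplified types (remapKeys is an illustrative name, not nvFuser code):

#include <unordered_map>

// Carry entries of old_map over to new keys via key_map (old -> new),
// dropping entries whose key is not mapped. This mirrors how
// updated_halo_extent_map is built from halo_extent_map_ and id_map.
template <typename K, typename V>
std::unordered_map<K, V> remapKeys(
    const std::unordered_map<K, V>& old_map,
    const std::unordered_map<K, K>& key_map) {
  std::unordered_map<K, V> remapped;
  for (const auto& [old_key, new_key] : key_map) {
    auto it = old_map.find(old_key);
    if (it != old_map.end()) {
      remapped[new_key] = it->second;
    }
  }
  return remapped;
}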
14 changes: 5 additions & 9 deletions torch/csrc/jit/codegen/cuda/index_compute.h
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
-#include <torch/csrc/jit/codegen/cuda/reference_tensor.h>
 #include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
 
 #include <unordered_map>
@@ -135,9 +134,8 @@ class IndexCompute : public BackwardVisitor {
   // if there's an option
   std::unordered_set<IterDomain*> preferred_paths_;
 
-  // Map from IterDomains to halo-extended extents in corresponding
-  // reference tensor
-  std::unordered_map<IterDomain*, Val*> reference_halo_extent_map_;
+  // Map from IterDomains to halo-extended extents
+  std::unordered_map<IterDomain*, Val*> halo_extent_map_;
 
   // Temporary flag which tells IndexCompute to use concrete id's from the exact
   // map rather than the actual IDs used in the ID expressions.
@@ -189,7 +187,7 @@ class IndexCompute : public BackwardVisitor {
       std::unordered_set<IterDomain*> zero_domains,
       std::unordered_set<IterDomain*> _zero_merged_in,
       std::unordered_set<IterDomain*> preferred_paths = {},
-      std::unordered_map<IterDomain*, Val*> reference_halo_extent_map = {});
+      std::unordered_map<IterDomain*, Val*> halo_extent_map = {});
 
   IndexCompute(
       const TensorDomain* _td,
@@ -199,7 +197,7 @@ class IndexCompute : public BackwardVisitor {
       std::unordered_set<IterDomain*> _zero_merged_in,
       const ContigIDs& contig_finder,
       std::unordered_set<IterDomain*> preferred_paths = {},
-      std::unordered_map<IterDomain*, Val*> reference_halo_extent_map = {});
+      std::unordered_map<IterDomain*, Val*> halo_extent_map = {});
 
   // Entry point used for using concrete id based traversal. This traversal is
   // assumed to start at leaf IDs provided by initial_index_map.
@@ -214,9 +212,7 @@ class IndexCompute : public BackwardVisitor {
   IndexCompute updateIndexCompute(
       const TensorDomain* new_td,
       const std::unordered_map<IterDomain*, IterDomain*>& id_map,
-      const ContigIDs& contig_finder,
-      const std::unordered_map<IterDomain*, Val*>& reference_halo_extent_map =
-          {}) const;
+      const ContigIDs& contig_finder) const;
 
   // Interface to run index traversal through loop indexing analysis result to
   // be used with the entry point for concrete id based traversal.
27 changes: 0 additions & 27 deletions torch/csrc/jit/codegen/cuda/reference_tensor.h

This file was deleted.
