Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988)
csarofeen authored Sep 27, 2022
1 parent 8f1c7f5 commit 15f2f6d
Showing 8 changed files with 20 additions and 17 deletions.
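In outline, the commit replaces a default-constructed ConcretizedBroadcastDomains member (previously filled in by a separate build() call) with a std::shared_ptr<const ConcretizedBroadcastDomains> that is created fully built during lowering, which is why call sites switch from `.isConcretized(...)` to `->isConcretized(...)`. Below is a minimal, self-contained sketch of that ownership pattern; Fusion, Analysis, and Lowering are simplified stand-ins, not the real nvFuser classes.

```cpp
// Sketch only: simplified stand-ins for the real nvFuser types.
#include <memory>

struct Fusion {};  // placeholder for the IR container

// The analysis is fully constructed from a Fusion; there is no separate
// build() step, so an "empty" instance can never be observed.
class Analysis {
 public:
  Analysis() = delete;
  explicit Analysis(const Fusion* /*fusion*/) { /* traverse IR, fill maps */ }
  bool isConcretized(int /*axis*/) const { return false; }  // read-only query
};

class Lowering {
 public:
  void lower(const Fusion* fusion) {
    // Build once, then expose as shared, immutable state.
    analysis_ = std::make_shared<const Analysis>(fusion);
  }
  // Returning the shared_ptr by value lets callers keep the analysis alive
  // independently of the Lowering object; queries use -> instead of .
  std::shared_ptr<const Analysis> analysis() const { return analysis_; }

 private:
  std::shared_ptr<const Analysis> analysis_;
};

int main() {
  Fusion fusion;
  Lowering lowering;
  lowering.lower(&fusion);
  return lowering.analysis()->isConcretized(0) ? 1 : 0;
}
```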
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -257,7 +257,8 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
  compute_at_map_->validateAndPropagatePType();

  // Used in parallel dimension map
- concretized_broadcast_domains_.build(fusion_);
+ concretized_broadcast_domains_ =
+     std::make_shared<const ConcretizedBroadcastDomains>(fusion_);

  parallelDimensionMap().build(fusion_);
  if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) {
6 changes: 4 additions & 2 deletions torch/csrc/jit/codegen/cuda/lower2device.h
@@ -62,7 +62,8 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
  //! Query if lowering is in progress
  static bool hasCurrent();

- ConcretizedBroadcastDomains& concretizedBroadcastDomains() {
+ std::shared_ptr<const ConcretizedBroadcastDomains>
+ concretizedBroadcastDomains() {
    return concretized_broadcast_domains_;
  }

@@ -194,7 +195,8 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
  // would be safer to wrap all of these in unique pointers and remove the build
  // interface and default constructor. That way they couldn't be accessed
  // without being initialized.
- ConcretizedBroadcastDomains concretized_broadcast_domains_;
+ std::shared_ptr<const ConcretizedBroadcastDomains>
+     concretized_broadcast_domains_;
  ThreadPredicateMap thread_pred_map_;
  PredicateElimination pred_elimination_;
  std::shared_ptr<ComputeAtMap> compute_at_map_;
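The comment retained in the hunk above ("would be safer to wrap all of these in unique pointers and remove the build interface and default constructor") describes a direction rather than something this commit fully implements. A hedged sketch of that idea, with hypothetical names standing in for the remaining value-type analysis members:

```cpp
// Sketch only: hypothetical names illustrating the unique_ptr direction
// mentioned in the comment, not the actual GpuLower implementation.
#include <cassert>
#include <memory>

struct Fusion {};

class ThreadPredicateAnalysis {
 public:
  explicit ThreadPredicateAnalysis(const Fusion* /*fusion*/) { /* build */ }
};

class LoweringState {
 public:
  void lower(const Fusion* fusion) {
    // The only way to obtain the analysis is fully built.
    thread_preds_ = std::make_unique<ThreadPredicateAnalysis>(fusion);
  }
  const ThreadPredicateAnalysis& threadPreds() const {
    assert(thread_preds_ && "accessed before lowering built it");
    return *thread_preds_;
  }

 private:
  std::unique_ptr<ThreadPredicateAnalysis> thread_preds_;
};

int main() {
  Fusion fusion;
  LoweringState state;
  state.lower(&fusion);
  (void)state.threadPreds();
  return 0;
}
```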
8 changes: 4 additions & 4 deletions torch/csrc/jit/codegen/cuda/lower_sync_information.cpp
@@ -26,7 +26,7 @@ void validateParallelizationOfTensor(TensorView* tv) {
  // It doesn't matter if this axis is a non-concretized broadcast
  // TODO: merging broadcast and non-broadcast
  if (axis->isBroadcast() &&
-     !GpuLower::current()->concretizedBroadcastDomains().isConcretized(
+     !GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
          axis)) {
    continue;
  }
@@ -195,7 +195,7 @@ void SyncMap::build(Fusion* fusion) {
  (!parallel_bcast_doms.get(consumer_ptype) ||
   !GpuLower::current()
        ->concretizedBroadcastDomains()
-       .isConcretized(consumer_axis))) {
+       ->isConcretized(consumer_axis))) {
    continue;
  }

@@ -421,7 +421,7 @@ void SyncMap::build(Fusion* fusion) {
      .redundant_types;

  if (p_id->isBroadcast() &&
-     GpuLower::current()->concretizedBroadcastDomains().isConcretized(
+     GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
          p_id) &&
      producer->getMemoryType() == MemoryType::Shared &&
      redundant_preds.hasTID()) {
@@ -436,7 +436,7 @@ void SyncMap::build(Fusion* fusion) {
  (!parallel_bcast_doms.get(producer_ptype) ||
   !GpuLower::current()
        ->concretizedBroadcastDomains()
-       .isConcretized(p_id))) {
+       ->isConcretized(p_id))) {
    continue;
  }

5 changes: 3 additions & 2 deletions torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
@@ -237,7 +237,7 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
    id_reductions.set(id->getParallelType());
  }
  if (id->isBroadcast() &&
-     GpuLower::current()->concretizedBroadcastDomains().isConcretized(
+     GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
          id)) {
    id_bcasts.set(id->getParallelType());
  }
@@ -575,7 +575,8 @@ ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains(

  for (auto id : iter_domains) {
    if (!id->isBroadcast() ||
-       !GpuLower::current()->concretizedBroadcastDomains().isConcretized(id)) {
+       !GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
+           id)) {
      continue;
    }
    if (id->isBlockDim() || (!output_smem && id->isThreadDim())) {
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp
@@ -10,7 +10,7 @@ namespace jit {
  namespace fuser {
  namespace cuda {

- void ConcretizedBroadcastDomains::build(Fusion* fusion) {
+ ConcretizedBroadcastDomains::ConcretizedBroadcastDomains(Fusion* fusion) {
    exact_map_ = std::make_unique<ExactRootDomainMap>(fusion);

  // Initialize the origin map with input broadcast domains
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h
@@ -23,7 +23,8 @@ namespace cuda {
  //! domains are marked as concretized.
  class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor {
   public:
-   void build(Fusion* fusion);
+   ConcretizedBroadcastDomains() = delete;
+   ConcretizedBroadcastDomains(Fusion* fusion);

    //! Is a domain concretized?
    bool isConcretized(IterDomain* id) const;
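At call sites, the header change above amounts to collapsing two-phase initialization into construction. A brief sketch of the call-site effect, using a stand-in type (the real class also derives from IterVisitor and has further queries):

```cpp
// Sketch of the call-site effect of deleting the default constructor.
struct Fusion {};

class ConcretizedBroadcastDomainsLike {
 public:
  ConcretizedBroadcastDomainsLike() = delete;  // no unbuilt instances
  explicit ConcretizedBroadcastDomainsLike(const Fusion* /*fusion*/) {}
};

int main() {
  Fusion fusion;
  // Before this commit (two-phase): construct empty, then call build(&fusion).
  // After: the analysis is built at construction and cannot exist unbuilt.
  ConcretizedBroadcastDomainsLike info(&fusion);
  // ConcretizedBroadcastDomainsLike empty;  // would not compile
  (void)info;
  return 0;
}
```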
3 changes: 1 addition & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
@@ -789,8 +789,7 @@ static bool checkPatternEquivalence(
  // being broadcasted to one size multiple times or different sizes. This is a
  // hard to optimize problem and likely indicates we shouldn't be fusing.
  bool hasNonUniqueBcast(Fusion* fusion) {
-   ConcretizedBroadcastDomains concretize_info;
-   concretize_info.build(fusion);
+   ConcretizedBroadcastDomains concretize_info(fusion);

    for (auto tv : ir_utils::allTvs(fusion)) {
      for (auto id : tv->getRootDomain()) {
7 changes: 3 additions & 4 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -20889,9 +20889,9 @@ TEST_F(NVFuserTest, FusionBroadcastConcretization1_CUDA) {
  }

  GpuLower gpulw(&fusion);
- TORCH_CHECK(!gpulw.concretizedBroadcastDomains().isConcretized(
+ TORCH_CHECK(!gpulw.concretizedBroadcastDomains()->isConcretized(
      loweredTv(tv4, gpulw)->axis(1)));
- TORCH_CHECK(gpulw.concretizedBroadcastDomains().isConcretized(
+ TORCH_CHECK(gpulw.concretizedBroadcastDomains()->isConcretized(
      loweredTv(tv7, gpulw)->axis(1)));

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
@@ -21079,8 +21079,7 @@ TEST_F(NVFuserTest, FusionBroadcastConcretization5_CUDA) {
  auto tvs3 = Welford(tv17, {1});
  fusion.addOutput(tvs3.avg);

- ConcretizedBroadcastDomains bcast_concretization_info;
- bcast_concretization_info.build(&fusion);
+ ConcretizedBroadcastDomains bcast_concretization_info(&fusion);

  TORCH_CHECK(
      bcast_concretization_info.maybeNonUniquelyConcretized(tv5->axis(1)),
