diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp
index edff04b201123..448ca865e07fe 100644
--- a/test/cpp/jit/test_gpu.cpp
+++ b/test/cpp/jit/test_gpu.cpp
@@ -13936,6 +13936,43 @@ TEST(NVFuserTest, FusionIssue728_CUDA) {
       "Only tv3 should be included");
 }
 
+TEST(NVFuserTest, FusionIssue757_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+  auto tv1 = sum(tv0, {1});
+  auto tv2 = broadcast(tv1, {false, true});
+  auto tv3 = makeSymbolicTensor(2);
+  fusion.addInput(tv3);
+  auto tv4 = add(tv2, tv3);
+  fusion.addOutput(tv4);
+
+  tv1->computeAt(tv4, -1);
+
+  tv4->axis(-1)->parallelize(ParallelType::TIDx);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+
+  int numel_x = 650;
+  int numel_y = 102;
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
+  at::Tensor t3 = at::randn({numel_x, numel_y}, options);
+  std::vector<IValue> inputs = {t0, t3};
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion(inputs);
+
+  auto t1 = t0.sum({1});
+  auto t2 = t1.unsqueeze(-1).expand({numel_x, numel_y});
+  auto t4 = t2 + t3;
+
+  testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
index 0a1fbffc9cf5c..55f88fc590711 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
+++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -357,18 +357,15 @@ void ComputeAtMap::build(Fusion* fusion, GpuLower* gpu_lower) {
     TORCH_INTERNAL_ASSERT(
         concrete_id != nullptr, "Could not concretize an IterDomain set.");
 
-    // If parallel mode, parallelize the the concrete id
-    // TODO: Would be good to simply keep a parallelization map and make lookups
-    // to it through lowering.
-    if (mapping_mode_ == MappingMode::PARALLEL) {
-      auto parallel_map_it = parallel_type_map_.find(set);
-      if (parallel_map_it != parallel_type_map_.end()) {
-        concrete_id->parallelize(parallel_map_it->second);
-      }
-    }
-
     for (auto id : *set) {
       concrete_id_map_[id] = concrete_id;
+      if (mapping_mode_ == MappingMode::PARALLEL) {
+        auto parallel_map_it = parallel_type_map_.find(set);
+        // Parallelize all IterDomains to simplify lowering and codegen
+        if (parallel_map_it != parallel_type_map_.end()) {
+          id->parallelize(parallel_map_it->second);
+        }
+      }
     }
   }
 
diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
index ea2b463db399a..b893f81d3040d 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
@@ -97,14 +97,12 @@ IterDomain::IterDomain(
   setName(iter_domain->name());
 }
 
+//! Note that the parallel dimension, if available, may be different
+//! from the actual extent of this IterDomain as the parallel
+//! dimension is determined by the largest extent of IterDomains
+//! sharing the same loop.
 Val* IterDomain::extent() const {
   TORCH_INTERNAL_ASSERT(extent_ != nullptr);
-  if (isThread()) {
-    if (extent_->isScalar() && extent_->isConst()) {
-      return extent_;
-    }
-    return NamedScalar::getParallelDim(parallelType());
-  }
   return extent_;
 }
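
The sketch below is a small, hand-written CUDA analogue (not nvfuser-generated code, and not part of the patch) of the situation the new kernel_ir.cpp comment describes: two statements bound to TIDx with different extents end up sharing one launch dimension, blockDim.x is sized by the larger extent, and each statement is guarded by its own extent. Returning the real extent from IterDomain::extent(), rather than the named parallel dimension, keeps that per-statement information available; the explicit predicate form, the kernel and variable names, and the second extent of 64 are illustrative assumptions, while 102 simply mirrors numel_y in the new test.

// Illustrative only -- assumes blockDim.x == max(extent_a, extent_b).
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void shared_tidx_kernel(
    const float* a,
    float* out_a,
    int extent_a, // extent of the first TIDx-bound domain
    const float* b,
    float* out_b,
    int extent_b) { // extent of the second TIDx-bound domain
  // Each statement is predicated on its own extent, not on blockDim.x.
  if (threadIdx.x < extent_a) {
    out_a[threadIdx.x] = a[threadIdx.x] * 2.0f;
  }
  if (threadIdx.x < extent_b) {
    out_b[threadIdx.x] = b[threadIdx.x] + 1.0f;
  }
}

int main() {
  const int extent_a = 102; // mirrors numel_y in FusionIssue757_CUDA
  const int extent_b = 64;  // hypothetical smaller extent
  const int bdimx = std::max(extent_a, extent_b); // the "parallel dimension"

  float *a, *out_a, *b, *out_b;
  cudaMallocManaged(&a, extent_a * sizeof(float));
  cudaMallocManaged(&out_a, extent_a * sizeof(float));
  cudaMallocManaged(&b, extent_b * sizeof(float));
  cudaMallocManaged(&out_b, extent_b * sizeof(float));
  for (int i = 0; i < extent_a; ++i) a[i] = 1.0f;
  for (int i = 0; i < extent_b; ++i) b[i] = 1.0f;

  // One block, sized by the largest extent sharing the thread dimension.
  shared_tidx_kernel<<<1, bdimx>>>(a, out_a, extent_a, b, out_b, extent_b);
  cudaDeviceSynchronize();

  printf("out_a[0]=%f out_b[0]=%f\n", out_a[0], out_b[0]);
  cudaFree(a); cudaFree(out_a); cudaFree(b); cudaFree(out_b);
  return 0;
}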