diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp index 50ac539c7a704..08517c441ebfd 100644 --- a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp @@ -21,43 +21,79 @@ namespace { //! Checks that the current loop nest is not realizing a serial //! broadcast so that each index of producer buffer will only -//! be visited once. -//! TODO: should refactor this utility now to use loop maps in a -//! follow up. +//! be visited once, which is the only case where aggressive +//! inner sharing is valid. +//! bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) { - auto producer_root = - TensorDomain::noReductions(producer->getMaybeRFactorDomain()); - auto consumer_root = - TensorDomain::noReductions(consumer->getMaybeRFactorDomain()); - - if (producer_root.size() != consumer_root.size()) { - // This case would be a single broadcast or a single reduce - // which wouldn't be a broadcast resolution - return false; + //! Note: see issue #1785: + //! serial broadcast resolution doesn't only happen to + //! immediate producers of broadcast ops. We can also have + //! example: + //! T1[I,B] = broadcast(T0[I]]) + //! T3[I,I] = T1[I,B] + T2[I,I] + //! T4[I,I] = T3[I,I] + //! and generates the following loop: + //! alloc T0[4] + //! For i in 0..3 + //! T0[...] = + //! + //! For j in 0...X: + //! alloc T3[4] + //! for k in 0..3: + //! alloc T1[1] + //! T1[0] = T0[k] // <- This is actually a broadcast resolution + //! T3[k] = T1[0] + T2[...] + //! T4[...] = T3[...] + //! + //! In this case we are actually visiting each pixel of T0 in each iteration + //! of the j loop while T1 was the broadcasted tensor causing this reuse. + //! + //! The current version of checking covers this scenario by checking the root + //! ids of the consumer concrete loop id's. Any time a local tensor like T0 + //! appears in a re-use scenario like above, we should see a serial loop id + //! that was derived from some root id that doesn't concretely map to T0's + //! domain. + + // Serial concrete loop id's that cover consumer's iter domain. + std::vector consumer_serial_loop_concrete_ids; + + for (auto consumer_leaf_id : consumer->domain()->domain()) { + auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID( + consumer_leaf_id, IdMappingMode::LOOP); + + // Check for any serial loop id with non-trivial extent + if (!concrete_loop_id->isThread() && + !concrete_loop_id->extent()->isOneInt()) { + consumer_serial_loop_concrete_ids.push_back(concrete_loop_id); + } } - std::vector serial_ids; - std::copy_if( - producer->domain()->domain().begin(), - producer->domain()->domain().end(), - std::back_inserter(serial_ids), - [](IterDomain* id) { return !id->isThread(); }); - - auto serial_producer_roots = - InputsOf::outputs(FusionGuard::getCurFusion(), serial_ids); - auto serial_root_id = - ir_utils::filterByType(serial_producer_roots); - std::unordered_set serial_producer_root_set( - serial_root_id.begin(), serial_root_id.end()); - - for (const auto idx : c10::irange(producer_root.size())) { - if (producer_root[idx]->isBroadcast() && - !consumer_root[idx]->isBroadcast()) { - // Check if this broadcast contributed to any serial - // scheduled iterdomains: - if (serial_producer_root_set.count(producer_root[idx])) { - return true; - } + // Collect the root id's that the serial loop iterdomain + // are transformed from. + auto serial_loop_roots = InputsOf::outputs( + FusionGuard::getCurFusion(), consumer_serial_loop_concrete_ids); + + // Collect exact concrete id's in producer's root domain + std::unordered_set producer_exact_concrete_root_ids; + auto producer_root = + TensorDomain::noReductions(producer->getMaybeRFactorDomain()); + std::transform( + producer_root.begin(), + producer_root.end(), + std::inserter( + producer_exact_concrete_root_ids, + producer_exact_concrete_root_ids.begin()), + ir_utils::caMapExactConcreteId); + + // Check if serial loop roots indexes any exact root id's that + // is not within the set of producer's root exact id's. These + // id's will imply that the same producer pixel is accessed + // in multiple iterations of the materialized serial loop. + for (auto serial_loop_root : + ir_utils::filterByType(serial_loop_roots)) { + if (!producer_exact_concrete_root_ids.count( + ir_utils::caMapExactConcreteId(serial_loop_root))) { + return true; } } @@ -998,7 +1034,6 @@ class AllocateReuseModifier { struct InPlaceSharingInfo { bool has_broadcast_between = false; bool has_unsupported_op = false; - bool has_serial_broadcast_resolution_between = false; }; //! Careful heavy check on inner sharing candidates, @@ -1044,13 +1079,6 @@ class AllocateReuseModifier { return false; } - // TODO: blanket disable reuse across broadcast concretization - // to unblock issue for now. - // Should improve the precision of this analysis in a follow up. - if (topo_info.has_serial_broadcast_resolution_between) { - return false; - } - // Get information on the allocated domains of the // two buffers auto& local_alloc_map = GpuLower::current()->localAllocationInfoMap(); @@ -1103,14 +1131,6 @@ class AllocateReuseModifier { info.has_unsupported_op = true; } } - - for (auto in_tv : - ir_utils::filterByType(tv_def->inputs())) { - if (all_used_val_set.count(in_tv) && - isSerialBroadcastResolution(in_tv, tv)) { - info.has_serial_broadcast_resolution_between = true; - } - } } } return info;