diff --git a/examples/madness/mra-device/mrattg-device.cc b/examples/madness/mra-device/mrattg-device.cc index b0b069d4b..1c460a97b 100644 --- a/examples/madness/mra-device/mrattg-device.cc +++ b/examples/madness/mra-device/mrattg-device.cc @@ -122,13 +122,14 @@ auto make_project( } template -static auto select_compress_send(const mra::Key& parent, Value&& value, +static auto select_compress_send(const mra::Key& key, Value&& value, std::size_t child_idx, std::index_sequence) { if (child_idx == I) { - return ttg::device::send(parent, std::forward(value)); + std::cout << "key " << key << " sends to parent " << key.parent() << " input " << I << std::endl; + return ttg::device::send(key.parent(), std::forward(value)); } else if constexpr (sizeof...(Is) > 0){ - return select_compress_send(parent, std::forward(value), child_idx, std::index_sequence{}); + return select_compress_send(key, std::forward(value), child_idx, std::index_sequence{}); } /* if we get here we messed up */ throw std::runtime_error("Mismatching number of children!"); @@ -139,7 +140,10 @@ static auto select_compress_send(const mra::Key& parent, Value&& value, * even though it will not actually perform any computation */ template static ttg::device::Task do_send_leafs_up(const mra::Key& key, const mra::FunctionReconstructedNode& node) { - co_await select_compress_send(key.parent(), node, key.childindex(), std::index_sequence::num_children>{}); + /* drop all inputs from nodes that are not leafs, they will be upstreamed by compress */ + if (!node.has_children()) { + co_await select_compress_send(key, node, key.childindex(), std::make_index_sequence::num_children>{}); + } } @@ -147,10 +151,10 @@ static ttg::device::Task do_send_leafs_up(const mra::Key& key, const mra:: template static auto make_compress( const mra::FunctionData& functiondata, - ttg::Edge, mra::FunctionReconstructedNode> in, - ttg::Edge, mra::FunctionCompressedNode> out) { + ttg::Edge, mra::FunctionReconstructedNode>& in, + ttg::Edge, mra::FunctionCompressedNode>& out) +{ static_assert(NDIM == 3); // TODO: worth fixing? - ttg::Edge, mra::FunctionReconstructedNode> recur("recur"); constexpr const std::size_t num_children = mra::Key::num_children; // creates the right number of edges for nodes to flow from send_leafs_up to compress @@ -159,9 +163,10 @@ static auto make_compress( return ttg::edges((Is, ttg::Edge, mra::FunctionReconstructedNode>{})...); }; auto send_to_compress_edges = create_edges(std::make_index_sequence{}); + /* append out edge to set of edges */ + auto compress_out_edges = std::tuple_cat(send_to_compress_edges, std::make_tuple(out)); /* use the tuple variant to handle variable number of inputs while suppressing the output tuple */ - auto do_compress = [&]// - (const mra::Key& key, + auto do_compress = [&](const mra::Key& key, //const std::tuple& input_frns const mra::FunctionReconstructedNode &in0, const mra::FunctionReconstructedNode &in1, @@ -174,6 +179,7 @@ static auto make_compress( //const typename ::detail::tree_types::compress_in_type& in, //typename ::detail::tree_types::compress_out_type& out) { constexpr const auto num_children = mra::Key::num_children; + constexpr const auto out_terminal_id = num_children; auto K = in0.coeffs.dim(0); mra::FunctionCompressedNode result(key, K); // The eventual result auto& d = result.coeffs; @@ -221,24 +227,27 @@ static auto make_compress( } // Recur up + std::cout << "compress key " << key << " parent " << key.parent() << " level " << key.level() << std::endl; if (key.level() > 0) { p.sum = tmp[num_children] + sumsq; // result sumsq is last element in sumsqs // will not return co_await ttg::device::forward( // select to which child of our parent we send - ttg::device::send<0>(key, std::move(p)), + //ttg::device::send<0>(key, std::move(p)), + select_compress_send(key, std::move(p), key.childindex(), std::make_index_sequence{}), // Send result to output tree - ttg::device::send<1>(key, std::move(result))); + ttg::device::send(key, std::move(result))); } else { std::cout << "At root of compressed tree: total normsq is " << sumsq + d_sumsq << std::endl; co_await ttg::device::forward( // Send result to output tree - ttg::device::send<1>(key, std::move(result))); + ttg::device::send(key, std::move(result))); } }; - return std::make_tuple(ttg::make_tt(&do_send_leafs_up, edges(ttg::fuse(recur, in)), send_to_compress_edges, "send_leaves_up"), - ttg::make_tt(std::move(do_compress), send_to_compress_edges, edges(recur,out), "do_compress")); + ttg::Edge, mra::FunctionReconstructedNode> recur("recur"); + return std::make_tuple(ttg::make_tt(&do_send_leafs_up, edges(in), send_to_compress_edges, "send_leaves_up"), + ttg::make_tt(std::move(do_compress), send_to_compress_edges, compress_out_edges, "do_compress")); } template