diff --git a/CMakeLists.txt b/CMakeLists.txt index 441c1306..c589f9ec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,7 +111,6 @@ add_library(larch src/mutation_annotated_dag.cpp src/node_label.cpp src/node_storage.cpp - src/node.cpp src/post_order_iterator.cpp src/pre_order_iterator.cpp) larch_compile_opts(larch) diff --git a/include/impl/node_impl.hpp b/include/impl/node_impl.hpp index 3e32c8b8..bece7149 100644 --- a/include/impl/node_impl.hpp +++ b/include/impl/node_impl.hpp @@ -1,6 +1,86 @@ // Functions defined here are documented where declared in `include/node.hpp` #include +template +NodeView::NodeView(T dag, NodeId id) : dag_{dag}, id_{id} { + static_assert(std::is_same_v or std::is_same_v); + Assert(id.value != NoId); + Assert(id.value < dag_.nodes_.size()); +} + +template +NodeView::operator Node() const { + return {dag_, id_}; +} + +template +NodeView::operator NodeId() const { + return id_; +} + +template +T NodeView::GetDAG() const { + return dag_; +} + +template +NodeId NodeView::GetId() const { + return id_; +} + +template +typename NodeView::EdgeType NodeView::GetSingleParent() const { + Assert(GetParents().size() == 1); + return *GetParents().begin(); +} + +template +bool NodeView::IsRoot() const { + return GetStorage().GetParents().empty(); +} + +template +bool NodeView::IsLeaf() const { + if (GetClades().empty()) { + return true; + } + auto children = GetChildren(); + return children.begin() == children.end(); +} + +template +void NodeView::AddParentEdge(Edge edge) const { + if constexpr (is_mutable) { + GetStorage().AddEdge(edge.GetClade(), edge.GetId(), false); + } +} + +template +void NodeView::AddChildEdge(Edge edge) const { + if constexpr (is_mutable) { + GetStorage().AddEdge(edge.GetClade(), edge.GetId(), true); + } +} + +template +void NodeView::RemoveParentEdge(Edge edge) const { + if constexpr (is_mutable) { + GetStorage().RemoveEdge(edge, false); + } +} + +template +const std::optional& NodeView::GetSampleId() const { + return GetStorage().GetSampleId(); +} + +template +void NodeView::SetSampleId(std::optional&& sample_id) { + if constexpr (is_mutable) { + GetStorage().SetSampleId(std::forward>(sample_id)); + } +} + template auto NodeView::GetParents() const { return GetStorage().GetParents() | Transform::ToEdges(dag_); diff --git a/include/impl/subtree_weight_impl.hpp b/include/impl/subtree_weight_impl.hpp index edd775dc..33341ed4 100644 --- a/include/impl/subtree_weight_impl.hpp +++ b/include/impl/subtree_weight_impl.hpp @@ -1,4 +1,5 @@ #include +#include template SubtreeWeight::SubtreeWeight(const MADAG& dag) @@ -54,20 +55,35 @@ MADAG SubtreeWeight::TrimToMinWeight(WeightOps&& weight_ops) { template std::pair> SubtreeWeight::SampleTree( WeightOps&& weight_ops) { - MADAG result{dag_.GetReferenceSequence()}; - std::vector result_dag_ids; - - ExtractTree( - dag_, dag_.GetDAG().GetRoot(), std::forward(weight_ops), - [this](Node node, CladeIdx clade_idx) { - auto clade = node.GetClade(clade_idx); - Assert(not clade.empty()); - std::uniform_int_distribution distribuition{0, clade.size() - 1}; - return clade.at(distribuition(random_generator_)); - }, - result, result_dag_ids); + return SampleTreeImpl(std::forward(weight_ops), [](auto clade) { + return std::uniform_int_distribution{0, clade.size() - 1}; + }); +} - return {std::move(result), std::move(result_dag_ids)}; +struct TreeCount; +template +std::pair> SubtreeWeight::UniformSampleTree( + WeightOps&& weight_ops) { + static_assert(std::is_same_v, TreeCount>, + "UniformSampleTree needs TreeCount"); + // Ensure cache is filled + ComputeWeightBelow(dag_.GetDAG().GetRoot(), std::forward(weight_ops)); + return SampleTreeImpl( + std::forward(weight_ops), [this, &weight_ops](auto clade) { + std::vector probabilities; + typename WeightOps::Weight sum{}; + for (NodeId child : clade | Transform::GetChild()) { + sum += cached_weights_.at(child.value).value(); + } + if (sum > 0) { + for (NodeId child : clade | Transform::GetChild()) { + probabilities.push_back( + static_cast(cached_weights_.at(child.value).value() / sum)); + } + } + return std::discrete_distribution{probabilities.begin(), + probabilities.end()}; + }); } template @@ -94,6 +110,25 @@ typename WeightOps::Weight SubtreeWeight::CladeWeight( return clade_result.first; } +template +template +std::pair> SubtreeWeight::SampleTreeImpl( + WeightOps&& weight_ops, DistributionMaker&& distribution_maker) { + MADAG result{dag_.GetReferenceSequence()}; + std::vector result_dag_ids; + + ExtractTree( + dag_, dag_.GetDAG().GetRoot(), std::forward(weight_ops), + [this, &distribution_maker](Node node, CladeIdx clade_idx) { + auto clade = node.GetClade(clade_idx); + Assert(not clade.empty()); + return clade.at(distribution_maker(clade)(random_generator_)); + }, + result, result_dag_ids); + + return {std::move(result), std::move(result_dag_ids)}; +} + template template void SubtreeWeight::ExtractTree(const MADAG& input_dag, Node node, @@ -136,10 +171,10 @@ void SubtreeWeight::ExtractTree(const MADAG& input_dag, Node node, for (auto node : result.GetDAG().GetNodes()) { size_t idx = node.GetId().value; - std::optional old_sample_id = + const std::optional& old_sample_id = input_dag.GetDAG().GetNodes().at(idx).GetSampleId(); - if (node.IsLeaf() and (bool) old_sample_id) { - node.SetSampleId(old_sample_id); + if (node.IsLeaf() and old_sample_id.has_value()) { + node.SetSampleId(std::optional{old_sample_id}); } } } diff --git a/include/leaf_set.hpp b/include/leaf_set.hpp index 65ea59a9..f8ae7697 100644 --- a/include/leaf_set.hpp +++ b/include/leaf_set.hpp @@ -1,8 +1,8 @@ #pragma once #include "common.hpp" +#include "compact_genome.hpp" -class CompactGenome; class NodeLabel; /** @@ -27,9 +27,9 @@ class LeafSet { LeafSet(Node node, const std::vector& labels, std::vector& computed_leafsets); - LeafSet(std::vector>&& clades); + inline LeafSet(std::vector>&& clades); - bool operator==(const LeafSet& rhs) const noexcept; + inline bool operator==(const LeafSet& rhs) const noexcept; [[nodiscard]] size_t Hash() const noexcept; @@ -43,8 +43,8 @@ class LeafSet { const std::vector>& GetClades() const; private: - static size_t ComputeHash( - const std::vector>& clades); + inline static size_t ComputeHash( + const std::vector>& clades) noexcept; }; template <> @@ -58,3 +58,22 @@ struct std::equal_to { return lhs == rhs; } }; + +bool LeafSet::operator==(const LeafSet& rhs) const noexcept { + return clades_ == rhs.clades_; +} + +LeafSet::LeafSet(std::vector>&& clades) + : clades_{std::forward>>(clades)}, + hash_{ComputeHash(clades_)} {} + +size_t LeafSet::ComputeHash( + const std::vector>& clades) noexcept { + size_t hash = 0; + for (auto& clade : clades) { + for (auto leaf : clade) { + hash = HashCombine(hash, leaf->Hash()); + } + } + return hash; +} \ No newline at end of file diff --git a/include/node.hpp b/include/node.hpp index 6e340211..cac6f14e 100644 --- a/include/node.hpp +++ b/include/node.hpp @@ -26,14 +26,14 @@ class NodeView { constexpr static const bool is_mutable = std::is_same_v; using NodeType = std::conditional_t; using EdgeType = std::conditional_t; - NodeView(T dag, NodeId id); - operator Node() const; - operator NodeId() const; + inline NodeView(T dag, NodeId id); + inline operator Node() const; + inline operator NodeId() const; /** * Return DAG-like object containing this node */ - T GetDAG() const; - NodeId GetId() const; + inline T GetDAG() const; + inline NodeId GetId() const; /** * Return a range containing parent Edge objects */ @@ -50,7 +50,7 @@ class NodeView { /** * Return the count of child clades */ - size_t GetCladesCount() const; + inline size_t GetCladesCount() const; /** * Return a range containing child Edges */ @@ -58,21 +58,21 @@ class NodeView { /** * Return a single parent edge of this node */ - EdgeType GetSingleParent() const; + inline EdgeType GetSingleParent() const; /** * Checks if node has no parents */ - bool IsRoot() const; + inline bool IsRoot() const; /** * Checks if node has no children */ - bool IsLeaf() const; - void AddParentEdge(Edge edge) const; - void AddChildEdge(Edge edge) const; - void RemoveParentEdge(Edge edge) const; + inline bool IsLeaf() const; + inline void AddParentEdge(Edge edge) const; + inline void AddChildEdge(Edge edge) const; + inline void RemoveParentEdge(Edge edge) const; - const std::optional GetSampleId() const; - void SetSampleId(std::optional sample_id); + inline const std::optional& GetSampleId() const; + inline void SetSampleId(std::optional&& sample_id); private: auto& GetStorage() const; diff --git a/include/node_storage.hpp b/include/node_storage.hpp index c19f0b5d..552e364f 100644 --- a/include/node_storage.hpp +++ b/include/node_storage.hpp @@ -18,8 +18,8 @@ class NodeStorage { */ const std::vector>& GetClades() const; - const std::optional GetSampleId() const; - void SetSampleId(std::optional sample_id); + const std::optional& GetSampleId() const; + void SetSampleId(std::optional&& sample_id); /** * Remove all parent and child edges diff --git a/include/subtree_weight.hpp b/include/subtree_weight.hpp index f89acad0..d6a5508f 100644 --- a/include/subtree_weight.hpp +++ b/include/subtree_weight.hpp @@ -44,10 +44,17 @@ class SubtreeWeight { [[nodiscard]] std::pair> SampleTree( WeightOps&& weight_ops); + [[nodiscard]] std::pair> UniformSampleTree( + WeightOps&& weight_ops); + private: template typename WeightOps::Weight CladeWeight(CladeRange&& clade, WeightOps&& weight_ops); + template + [[nodiscard]] std::pair> SampleTreeImpl( + WeightOps&& weight_ops, DistributionMaker&& distribution_maker); + template void ExtractTree(const MADAG& input_dag, Node node, WeightOps&& weight_ops, EdgeSelector&& edge_selector, MADAG& result, diff --git a/src/dag_loader.cpp b/src/dag_loader.cpp index 9482f476..1a231392 100644 --- a/src/dag_loader.cpp +++ b/src/dag_loader.cpp @@ -99,7 +99,7 @@ MADAG LoadTreeFromProtobuf(std::string_view path, std::string_view reference_seq result.BuildConnections(); for (auto node : result.GetDAG().GetNodes()) { if (node.IsLeaf()) { - node.SetSampleId(seq_ids[node.GetId().value]); + node.SetSampleId(std::move(seq_ids[node.GetId().value])); } } diff --git a/src/leaf_set.cpp b/src/leaf_set.cpp index 6c9fa477..faeb3e71 100644 --- a/src/leaf_set.cpp +++ b/src/leaf_set.cpp @@ -6,7 +6,6 @@ #include #include "dag.hpp" -#include "compact_genome.hpp" #include "node_label.hpp" const LeafSet* LeafSet::Empty() { @@ -42,13 +41,6 @@ LeafSet::LeafSet(Node node, const std::vector& labels, }()}, hash_{ComputeHash(clades_)} {} -LeafSet::LeafSet(std::vector>&& clades) - : clades_{clades}, hash_{ComputeHash(clades_)} {} - -bool LeafSet::operator==(const LeafSet& rhs) const noexcept { - return clades_ == rhs.clades_; -} - size_t LeafSet::Hash() const noexcept { return hash_; } auto LeafSet::begin() const -> decltype(clades_.begin()) { return clades_.begin(); } @@ -69,14 +61,3 @@ std::vector LeafSet::ToParentClade() const { const std::vector>& LeafSet::GetClades() const { return clades_; } - -size_t LeafSet::ComputeHash( - const std::vector>& clades) { - size_t hash = 0; - for (auto& clade : clades) { - for (auto leaf : clade) { - hash = HashCombine(hash, leaf->Hash()); - } - } - return hash; -} diff --git a/src/node.cpp b/src/node.cpp deleted file mode 100644 index ea008d49..00000000 --- a/src/node.cpp +++ /dev/null @@ -1,83 +0,0 @@ -#include "node.hpp" - -#include "dag.hpp" - -template -NodeView::NodeView(T dag, NodeId id) : dag_{dag}, id_{id} { - static_assert(std::is_same_v or std::is_same_v); - Assert(id.value != NoId); - Assert(id.value < dag_.nodes_.size()); -} - -template -NodeView::operator Node() const { - return {dag_, id_}; -} - -template -NodeView::operator NodeId() const { - return id_; -} - -template -T NodeView::GetDAG() const { - return dag_; -} - -template -NodeId NodeView::GetId() const { - return id_; -} - -template -typename NodeView::EdgeType NodeView::GetSingleParent() const { - Assert(GetParents().size() == 1); - return *GetParents().begin(); -} - -template -bool NodeView::IsRoot() const { - return GetStorage().GetParents().empty(); -} - -template -bool NodeView::IsLeaf() const { - auto children = GetChildren(); - return children.begin() == children.end(); -} - -template -void NodeView::AddParentEdge(Edge edge) const { - if constexpr (is_mutable) { - GetStorage().AddEdge(edge.GetClade(), edge.GetId(), false); - } -} - -template -void NodeView::AddChildEdge(Edge edge) const { - if constexpr (is_mutable) { - GetStorage().AddEdge(edge.GetClade(), edge.GetId(), true); - } -} - -template -void NodeView::RemoveParentEdge(Edge edge) const { - if constexpr (is_mutable) { - GetStorage().RemoveEdge(edge, false); - } -} - -template -const std::optional NodeView::GetSampleId() const { - return GetStorage().GetSampleId(); -} - -template -void NodeView::SetSampleId(std::optional sample_id) { - if constexpr (is_mutable) { - GetStorage().SetSampleId(sample_id); - } -} - -template class NodeView; -template class NodeView; diff --git a/src/node_storage.cpp b/src/node_storage.cpp index ef65fadb..fc633f0c 100644 --- a/src/node_storage.cpp +++ b/src/node_storage.cpp @@ -19,10 +19,12 @@ void NodeStorage::AddEdge(CladeIdx clade, EdgeId id, bool this_node_is_parent) { } } -const std::optional NodeStorage::GetSampleId() const { return sample_id_; } +const std::optional& NodeStorage::GetSampleId() const { + return sample_id_; +} -void NodeStorage::SetSampleId(std::optional sample_id) { - sample_id_ = sample_id; +void NodeStorage::SetSampleId(std::optional&& sample_id) { + sample_id_ = std::forward>(sample_id); } void NodeStorage::RemoveEdge(Edge edge, bool this_node_is_parent) { diff --git a/test/test_sample_tree.cpp b/test/test_sample_tree.cpp index 454d8742..02b739b3 100644 --- a/test/test_sample_tree.cpp +++ b/test/test_sample_tree.cpp @@ -1,3 +1,4 @@ +#include "mutation_annotated_dag.hpp" #include "subtree_weight.hpp" #include "parsimony_score.hpp" @@ -8,6 +9,8 @@ #include "dag_loader.hpp" +#include "tree_count.hpp" + static void test_sample_tree(MADAG& dag) { if (dag.GetEdgeMutations().empty()) { dag.RecomputeEdgeMutations(); @@ -16,8 +19,11 @@ static void test_sample_tree(MADAG& dag) { SubtreeWeight weight(dag); MADAG result = weight.SampleTree({}).first; - assert_true(result.GetDAG().IsTree(), "Tree"); + + SubtreeWeight tree_count{dag}; + MADAG result2 = tree_count.UniformSampleTree({}).first; + assert_true(result2.GetDAG().IsTree(), "Tree"); } static void test_sample_tree(std::string_view path) { diff --git a/tools/larch-usher.cpp b/tools/larch-usher.cpp index c6d587c9..aa5fed4e 100644 --- a/tools/larch-usher.cpp +++ b/tools/larch-usher.cpp @@ -19,10 +19,13 @@ #include "tree_count.hpp" #include "parsimony_score.hpp" #include "merge.hpp" +#include "benchmark.hpp" #include #include "../deps/usher/src/matOptimize/Profitable_Moves_Enumerators/Profitable_Moves_Enumerators.hpp" +#include + MADAG optimize_dag_direct(const MADAG& dag, Move_Found_Callback& callback); [[noreturn]] static void Usage() { std::cout << "Usage:\n"; @@ -33,7 +36,8 @@ MADAG optimize_dag_direct(const MADAG& dag, Move_Found_Callback& callback); std::cout << " -m,--matopt Path to matOptimize executable. Default: matOptimize\n"; std::cout << " -l,--logfile Name for logging csv file. Default: logfile.csv\n"; std::cout << " -c,--count Number of iterations. Default: 1\n"; - std::cout << " -r,--MAT-refseq-file Provide a path to a file containing a reference sequence\nif input points to MAT protobuf\n"; + std::cout << " -r,--MAT-refseq-file Provide a path to a file containing a " + "reference sequence\nif input points to MAT protobuf\n"; std::exit(EXIT_SUCCESS); } @@ -237,15 +241,16 @@ int main(int argc, char** argv) { MADAG input_dag; if (!refseq_path.empty()) { - // we should really take a fasta with one record, or at least remove - // newlines - std::string refseq; - std::fstream file; - file.open(refseq_path); - while (file >> refseq) {} - input_dag = LoadTreeFromProtobuf(input_dag_path, refseq); + // we should really take a fasta with one record, or at least remove + // newlines + std::string refseq; + std::fstream file; + file.open(refseq_path); + while (file >> refseq) { + } + input_dag = LoadTreeFromProtobuf(input_dag_path, refseq); } else { - input_dag = LoadDAGFromProtobuf(input_dag_path); + input_dag = LoadDAGFromProtobuf(input_dag_path); } Merge merge{input_dag.GetReferenceSequence()}; merge.AddDAGs({input_dag}); @@ -270,6 +275,9 @@ int main(int argc, char** argv) { }; logger(0); + CALLGRIND_START_INSTRUMENTATION; + Benchmark loop_time; + loop_time.start(); for (size_t i = 0; i < count; ++i) { std::cout << "############ Beginning optimize loop " << std::to_string(i) << " #######\n"; @@ -285,6 +293,9 @@ int main(int argc, char** argv) { logger(i + 1); } + loop_time.stop(); + std::cout << " Loop ended in " << loop_time.durationMs() << " ms.\n"; + CALLGRIND_STOP_INSTRUMENTATION; logfile.close(); StoreDAGToProtobuf(merge.GetResult().GetDAG(),