Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uniform sampling #31

Merged
merged 3 commits into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ add_library(larch
src/mutation_annotated_dag.cpp
src/node_label.cpp
src/node_storage.cpp
src/node.cpp
src/post_order_iterator.cpp
src/pre_order_iterator.cpp)
larch_compile_opts(larch)
Expand Down
80 changes: 80 additions & 0 deletions include/impl/node_impl.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,86 @@
// Functions defined here are documented where declared in `include/node.hpp`
#include <range/v3/view/join.hpp>

template <typename T>
NodeView<T>::NodeView(T dag, NodeId id) : dag_{dag}, id_{id} {
static_assert(std::is_same_v<T, DAG&> or std::is_same_v<T, const DAG&>);
Assert(id.value != NoId);
Assert(id.value < dag_.nodes_.size());
}

template <typename T>
NodeView<T>::operator Node() const {
return {dag_, id_};
}

template <typename T>
NodeView<T>::operator NodeId() const {
return id_;
}

template <typename T>
T NodeView<T>::GetDAG() const {
return dag_;
}

template <typename T>
NodeId NodeView<T>::GetId() const {
return id_;
}

template <typename T>
typename NodeView<T>::EdgeType NodeView<T>::GetSingleParent() const {
Assert(GetParents().size() == 1);
return *GetParents().begin();
}

template <typename T>
bool NodeView<T>::IsRoot() const {
return GetStorage().GetParents().empty();
}

template <typename T>
bool NodeView<T>::IsLeaf() const {
if (GetClades().empty()) {
return true;
}
auto children = GetChildren();
return children.begin() == children.end();
}

template <typename T>
void NodeView<T>::AddParentEdge(Edge edge) const {
if constexpr (is_mutable) {
GetStorage().AddEdge(edge.GetClade(), edge.GetId(), false);
}
}

template <typename T>
void NodeView<T>::AddChildEdge(Edge edge) const {
if constexpr (is_mutable) {
GetStorage().AddEdge(edge.GetClade(), edge.GetId(), true);
}
}

template <typename T>
void NodeView<T>::RemoveParentEdge(Edge edge) const {
if constexpr (is_mutable) {
GetStorage().RemoveEdge(edge, false);
}
}

template <typename T>
const std::optional<std::string>& NodeView<T>::GetSampleId() const {
return GetStorage().GetSampleId();
}

template <typename T>
void NodeView<T>::SetSampleId(std::optional<std::string>&& sample_id) {
if constexpr (is_mutable) {
GetStorage().SetSampleId(std::forward<std::optional<std::string>>(sample_id));
}
}

template <typename T>
auto NodeView<T>::GetParents() const {
return GetStorage().GetParents() | Transform::ToEdges(dag_);
Expand Down
67 changes: 51 additions & 16 deletions include/impl/subtree_weight_impl.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <algorithm>
#include <type_traits>

template <typename WeightOps>
SubtreeWeight<WeightOps>::SubtreeWeight(const MADAG& dag)
Expand Down Expand Up @@ -54,20 +55,35 @@ MADAG SubtreeWeight<WeightOps>::TrimToMinWeight(WeightOps&& weight_ops) {
template <typename WeightOps>
std::pair<MADAG, std::vector<NodeId>> SubtreeWeight<WeightOps>::SampleTree(
WeightOps&& weight_ops) {
MADAG result{dag_.GetReferenceSequence()};
std::vector<NodeId> result_dag_ids;

ExtractTree(
dag_, dag_.GetDAG().GetRoot(), std::forward<WeightOps>(weight_ops),
[this](Node node, CladeIdx clade_idx) {
auto clade = node.GetClade(clade_idx);
Assert(not clade.empty());
std::uniform_int_distribution<size_t> distribuition{0, clade.size() - 1};
return clade.at(distribuition(random_generator_));
},
result, result_dag_ids);
return SampleTreeImpl(std::forward<WeightOps>(weight_ops), [](auto clade) {
return std::uniform_int_distribution<size_t>{0, clade.size() - 1};
});
}

return {std::move(result), std::move(result_dag_ids)};
struct TreeCount;
template <typename WeightOps>
std::pair<MADAG, std::vector<NodeId>> SubtreeWeight<WeightOps>::UniformSampleTree(
WeightOps&& weight_ops) {
static_assert(std::is_same_v<std::decay_t<WeightOps>, TreeCount>,
"UniformSampleTree needs TreeCount");
// Ensure cache is filled
ComputeWeightBelow(dag_.GetDAG().GetRoot(), std::forward<WeightOps>(weight_ops));
return SampleTreeImpl(
std::forward<WeightOps>(weight_ops), [this, &weight_ops](auto clade) {
std::vector<double> probabilities;
typename WeightOps::Weight sum{};
for (NodeId child : clade | Transform::GetChild()) {
sum += cached_weights_.at(child.value).value();
}
if (sum > 0) {
for (NodeId child : clade | Transform::GetChild()) {
probabilities.push_back(
static_cast<double>(cached_weights_.at(child.value).value() / sum));
}
}
return std::discrete_distribution<size_t>{probabilities.begin(),
probabilities.end()};
});
}

template <typename WeightOps>
Expand All @@ -94,6 +110,25 @@ typename WeightOps::Weight SubtreeWeight<WeightOps>::CladeWeight(
return clade_result.first;
}

template <typename WeightOps>
template <typename DistributionMaker>
std::pair<MADAG, std::vector<NodeId>> SubtreeWeight<WeightOps>::SampleTreeImpl(
WeightOps&& weight_ops, DistributionMaker&& distribution_maker) {
MADAG result{dag_.GetReferenceSequence()};
std::vector<NodeId> result_dag_ids;

ExtractTree(
dag_, dag_.GetDAG().GetRoot(), std::forward<WeightOps>(weight_ops),
[this, &distribution_maker](Node node, CladeIdx clade_idx) {
auto clade = node.GetClade(clade_idx);
Assert(not clade.empty());
return clade.at(distribution_maker(clade)(random_generator_));
},
result, result_dag_ids);

return {std::move(result), std::move(result_dag_ids)};
}

template <typename WeightOps>
template <typename EdgeSelector>
void SubtreeWeight<WeightOps>::ExtractTree(const MADAG& input_dag, Node node,
Expand Down Expand Up @@ -136,10 +171,10 @@ void SubtreeWeight<WeightOps>::ExtractTree(const MADAG& input_dag, Node node,

for (auto node : result.GetDAG().GetNodes()) {
size_t idx = node.GetId().value;
std::optional<std::string> old_sample_id =
const std::optional<std::string>& old_sample_id =
input_dag.GetDAG().GetNodes().at(idx).GetSampleId();
if (node.IsLeaf() and (bool) old_sample_id) {
node.SetSampleId(old_sample_id);
if (node.IsLeaf() and old_sample_id.has_value()) {
node.SetSampleId(std::optional<std::string>{old_sample_id});
}
}
}
29 changes: 24 additions & 5 deletions include/leaf_set.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include "common.hpp"
#include "compact_genome.hpp"

class CompactGenome;
class NodeLabel;

/**
Expand All @@ -27,9 +27,9 @@ class LeafSet {
LeafSet(Node node, const std::vector<NodeLabel>& labels,
std::vector<LeafSet>& computed_leafsets);

LeafSet(std::vector<std::vector<const CompactGenome*>>&& clades);
inline LeafSet(std::vector<std::vector<const CompactGenome*>>&& clades);

bool operator==(const LeafSet& rhs) const noexcept;
inline bool operator==(const LeafSet& rhs) const noexcept;

[[nodiscard]] size_t Hash() const noexcept;

Expand All @@ -43,8 +43,8 @@ class LeafSet {
const std::vector<std::vector<const CompactGenome*>>& GetClades() const;

private:
static size_t ComputeHash(
const std::vector<std::vector<const CompactGenome*>>& clades);
inline static size_t ComputeHash(
const std::vector<std::vector<const CompactGenome*>>& clades) noexcept;
};

template <>
Expand All @@ -58,3 +58,22 @@ struct std::equal_to<LeafSet> {
return lhs == rhs;
}
};

bool LeafSet::operator==(const LeafSet& rhs) const noexcept {
return clades_ == rhs.clades_;
}

LeafSet::LeafSet(std::vector<std::vector<const CompactGenome*>>&& clades)
: clades_{std::forward<std::vector<std::vector<const CompactGenome*>>>(clades)},
hash_{ComputeHash(clades_)} {}

size_t LeafSet::ComputeHash(
const std::vector<std::vector<const CompactGenome*>>& clades) noexcept {
size_t hash = 0;
for (auto& clade : clades) {
for (auto leaf : clade) {
hash = HashCombine(hash, leaf->Hash());
}
}
return hash;
}
28 changes: 14 additions & 14 deletions include/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ class NodeView {
constexpr static const bool is_mutable = std::is_same_v<T, DAG&>;
using NodeType = std::conditional_t<is_mutable, MutableNode, Node>;
using EdgeType = std::conditional_t<is_mutable, MutableEdge, Edge>;
NodeView(T dag, NodeId id);
operator Node() const;
operator NodeId() const;
inline NodeView(T dag, NodeId id);
inline operator Node() const;
inline operator NodeId() const;
/**
* Return DAG-like object containing this node
*/
T GetDAG() const;
NodeId GetId() const;
inline T GetDAG() const;
inline NodeId GetId() const;
/**
* Return a range containing parent Edge objects
*/
Expand All @@ -50,29 +50,29 @@ class NodeView {
/**
* Return the count of child clades
*/
size_t GetCladesCount() const;
inline size_t GetCladesCount() const;
/**
* Return a range containing child Edges
*/
auto GetChildren() const;
/**
* Return a single parent edge of this node
*/
EdgeType GetSingleParent() const;
inline EdgeType GetSingleParent() const;
/**
* Checks if node has no parents
*/
bool IsRoot() const;
inline bool IsRoot() const;
/**
* Checks if node has no children
*/
bool IsLeaf() const;
void AddParentEdge(Edge edge) const;
void AddChildEdge(Edge edge) const;
void RemoveParentEdge(Edge edge) const;
inline bool IsLeaf() const;
inline void AddParentEdge(Edge edge) const;
inline void AddChildEdge(Edge edge) const;
inline void RemoveParentEdge(Edge edge) const;

const std::optional<std::string> GetSampleId() const;
void SetSampleId(std::optional<std::string> sample_id);
inline const std::optional<std::string>& GetSampleId() const;
inline void SetSampleId(std::optional<std::string>&& sample_id);

private:
auto& GetStorage() const;
Expand Down
4 changes: 2 additions & 2 deletions include/node_storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class NodeStorage {
*/
const std::vector<std::vector<EdgeId>>& GetClades() const;

const std::optional<std::string> GetSampleId() const;
void SetSampleId(std::optional<std::string> sample_id);
const std::optional<std::string>& GetSampleId() const;
void SetSampleId(std::optional<std::string>&& sample_id);

/**
* Remove all parent and child edges
Expand Down
7 changes: 7 additions & 0 deletions include/subtree_weight.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,17 @@ class SubtreeWeight {
[[nodiscard]] std::pair<MADAG, std::vector<NodeId>> SampleTree(
WeightOps&& weight_ops);

[[nodiscard]] std::pair<MADAG, std::vector<NodeId>> UniformSampleTree(
WeightOps&& weight_ops);

private:
template <typename CladeRange>
typename WeightOps::Weight CladeWeight(CladeRange&& clade, WeightOps&& weight_ops);

template <typename DistributionMaker>
[[nodiscard]] std::pair<MADAG, std::vector<NodeId>> SampleTreeImpl(
WeightOps&& weight_ops, DistributionMaker&& distribution_maker);

template <typename EdgeSelector>
void ExtractTree(const MADAG& input_dag, Node node, WeightOps&& weight_ops,
EdgeSelector&& edge_selector, MADAG& result,
Expand Down
2 changes: 1 addition & 1 deletion src/dag_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ MADAG LoadTreeFromProtobuf(std::string_view path, std::string_view reference_seq
result.BuildConnections();
for (auto node : result.GetDAG().GetNodes()) {
if (node.IsLeaf()) {
node.SetSampleId(seq_ids[node.GetId().value]);
node.SetSampleId(std::move(seq_ids[node.GetId().value]));
}
}

Expand Down
19 changes: 0 additions & 19 deletions src/leaf_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <range/v3/range/conversion.hpp>

#include "dag.hpp"
#include "compact_genome.hpp"
#include "node_label.hpp"

const LeafSet* LeafSet::Empty() {
Expand Down Expand Up @@ -42,13 +41,6 @@ LeafSet::LeafSet(Node node, const std::vector<NodeLabel>& labels,
}()},
hash_{ComputeHash(clades_)} {}

LeafSet::LeafSet(std::vector<std::vector<const CompactGenome*>>&& clades)
: clades_{clades}, hash_{ComputeHash(clades_)} {}

bool LeafSet::operator==(const LeafSet& rhs) const noexcept {
return clades_ == rhs.clades_;
}

size_t LeafSet::Hash() const noexcept { return hash_; }

auto LeafSet::begin() const -> decltype(clades_.begin()) { return clades_.begin(); }
Expand All @@ -69,14 +61,3 @@ std::vector<const CompactGenome*> LeafSet::ToParentClade() const {
const std::vector<std::vector<const CompactGenome*>>& LeafSet::GetClades() const {
return clades_;
}

size_t LeafSet::ComputeHash(
const std::vector<std::vector<const CompactGenome*>>& clades) {
size_t hash = 0;
for (auto& clade : clades) {
for (auto leaf : clade) {
hash = HashCombine(hash, leaf->Hash());
}
}
return hash;
}
Loading