Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph Library Fixes #1521

Merged
merged 23 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions lib/utils/include/utils/containers/find.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FIND_H

#include <algorithm>
#include <unordered_set>

namespace FlexFlow {

Expand All @@ -11,6 +12,12 @@ typename Container::const_iterator
return std::find(c.cbegin(), c.cend(), e);
}

template <typename V>
typename std::unordered_set<V>::const_iterator
find(std::unordered_set<V> const &c, V const &e) {
return c.find(e);
}

} // namespace FlexFlow

#endif
25 changes: 22 additions & 3 deletions lib/utils/include/utils/containers/zip.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H

#include "fmt/format.h"
#include "utils/exception.h"
#include <tuple>
#include <utility>
#include <vector>

namespace FlexFlow {

template <typename L, typename R>
std::vector<std::pair<L, R>> zip(std::vector<L> const &l,
std::vector<R> const &r) {
std::vector<std::pair<L, R>>
zip(std::vector<L> const &l, std::vector<R> const &r, bool strict = false) {
if (strict && l.size() != r.size()) {
throw mk_runtime_error(fmt::format("When strict = true, vector sizes must "

Check warning on line 16 in lib/utils/include/utils/containers/zip.h

View check run for this annotation

Codecov / codecov/patch

lib/utils/include/utils/containers/zip.h#L16

Added line #L16 was not covered by tests
"match. Got vectors of length {} and {}",
l.size(),
r.size()));

Check warning on line 19 in lib/utils/include/utils/containers/zip.h

View check run for this annotation

Codecov / codecov/patch

lib/utils/include/utils/containers/zip.h#L18-L19

Added lines #L18 - L19 were not covered by tests
}

std::vector<std::pair<L, R>> result;
for (int i = 0; i < std::min(l.size(), r.size()); i++) {
result.push_back(std::make_pair(l.at(i), r.at(i)));
Expand All @@ -20,7 +29,17 @@
template <typename A, typename B, typename C>
std::vector<std::tuple<A, B, C>> zip(std::vector<A> const &a,
std::vector<B> const &b,
std::vector<C> const &c) {
std::vector<C> const &c,
bool strict = false) {
if (strict && (a.size() != b.size() || b.size() != c.size())) {
throw std::runtime_error(

Check warning on line 35 in lib/utils/include/utils/containers/zip.h

View check run for this annotation

Codecov / codecov/patch

lib/utils/include/utils/containers/zip.h#L35

Added line #L35 was not covered by tests
fmt::format("When strict = true, vectors sizes must match. Got vectors "
"of length {}, {} and {}",
a.size(),
b.size(),
c.size()));

Check warning on line 40 in lib/utils/include/utils/containers/zip.h

View check run for this annotation

Codecov / codecov/patch

lib/utils/include/utils/containers/zip.h#L38-L40

Added lines #L38 - L40 were not covered by tests
}

std::vector<std::tuple<A, B, C>> result;
for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) {
result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i)));
Expand Down
260 changes: 144 additions & 116 deletions lib/utils/include/utils/graph/README.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions lib/utils/include/utils/graph/algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@ std::unordered_set<Node> get_neighbors(DiGraphView const &, Node const &);
// &);

// return the set of nodes without incoming edges
std::unordered_set<Node> get_sources(DiGraphView const &);
std::unordered_set<Node> get_initial_nodes(DiGraphView const &);

// return the set of nodes without outgoing edges
std::unordered_set<Node> get_sinks(DiGraphView const &);
std::unordered_set<Node> get_terminal_nodes(DiGraphView const &);

// std::unordered_set<Node> get_closed_sources(OpenMultiDiGraphView const &g);
// std::unordered_set<Node> get_closed_sinks(OpenMultiDiGraphView const &g);
Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ struct DataflowGraph : virtual public DataflowGraphView {
private:
IDataflowGraph &get_interface();
IDataflowGraph const &get_interface() const;

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
4 changes: 2 additions & 2 deletions lib/utils/include/utils/graph/digraph/algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
namespace FlexFlow {

std::unordered_set<DirectedEdge> get_edges(DiGraphView const &);
std::unordered_set<Node> get_sources(DiGraphView const &);
std::unordered_set<Node> get_sinks(DiGraphView const &);
std::unordered_set<Node> get_initial_nodes(DiGraphView const &);
std::unordered_set<Node> get_terminal_nodes(DiGraphView const &);

} // namespace FlexFlow

Expand Down
15 changes: 15 additions & 0 deletions lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,22 @@

namespace FlexFlow {

/**
* @brief See https://en.wikipedia.org/wiki/Dominator_(graph_theory)
*
* @note By definition, the root node dominates every node and every node
* dominates itself.
*
*/
std::unordered_set<Node> get_dominators(DiGraphView const &, Node const &);

/**
* @brief Returns the intersection of the dominators of the given set of nodes.
* @note This is conceptually equivalent to merging the given set of nodes and
* then finding the set of dominators of the new merged node (where merged means
* that all edges belonging to the set of nodes now pass through a single
* unified node).
*/
std::unordered_set<Node> get_dominators(DiGraphView const &,
std::unordered_set<Node> const &);

Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/digraph/digraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ struct DiGraph : virtual DiGraphView {
private:
IDiGraph &get_ptr();
IDiGraph const &get_ptr() const;

friend struct GraphInternal;
};
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraph);

Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/digraph/digraph_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ struct DiGraphView : virtual public GraphView {

private:
IDiGraphView const &get_ptr() const;

friend struct GraphInternal;
};
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ namespace FlexFlow {
std::unordered_set<MultiDiEdge> get_incoming_edges(MultiDiGraphView const &,
Node const &);

std::unordered_map<Node, std::unordered_set<MultiDiEdge>>
get_incoming_edges(MultiDiGraphView const &g,
std::unordered_set<Node> const &nodes);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_ALGORITHMS_GET_OUTGOING_EDGES_H

#include "utils/graph/multidigraph/multidigraph_view.h"
#include <unordered_set>

namespace FlexFlow {

std::unordered_set<MultiDiEdge> get_outgoing_edges(MultiDiGraphView const &,
Node const &);

std::unordered_map<Node, std::unordered_set<MultiDiEdge>>
get_outgoing_edges(MultiDiGraphView const &g,
std::unordered_set<Node> const &ns);

} // namespace FlexFlow

#endif
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/multidigraph/multidigraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ struct MultiDiGraph : virtual public MultiDiGraphView {
private:
IMultiDiGraph &get_interface();
IMultiDiGraph const &get_interface() const;

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/node/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ struct Graph : virtual GraphView {
private:
IGraph const &get_ptr() const;
IGraph &get_ptr();

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/node/graph_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ struct GraphView {
GraphView();
cow_ptr_t<IGraphView> ptr;
GraphView(cow_ptr_t<IGraphView> ptr);

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
namespace = "FlexFlow"
name = "ExtendedParallelReduction"
features = [
"eq",
"hash",
"fmt",
]

includes = [
"utils/graph/multidigraph/multidiedge.dtg.h",
"<unordered_set>"
]

src_includes = [
"utils/hash/unordered_set.h",
"utils/fmt/unordered_set.h",
]

[[fields]]
name = "edges"
type = "std::unordered_set<::FlexFlow::MultiDiEdge>"
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
namespace = "FlexFlow"
name = "ExtendedSeriesReduction"
features = [
"eq",
"hash",
"fmt",
]

includes = [
"utils/graph/multidigraph/multidiedge.dtg.h",
"<vector>"
]

src_includes = [
"utils/hash/vector.h",
"utils/fmt/vector.h",
]

[[fields]]
name = "edges"
type = "std::vector<::FlexFlow::MultiDiEdge>"
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
#include "utils/graph/digraph/digraph.h"
#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h"
#include "utils/optional.h"
#include <variant>
#include <vector>

namespace FlexFlow {

Expand Down
22 changes: 22 additions & 0 deletions lib/utils/include/utils/graph/series_parallel/parallel_reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,40 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H

#include "utils/graph/multidigraph/multidigraph.h"
#include "utils/graph/series_parallel/extended_parallel_reduction.dtg.h"
#include "utils/graph/series_parallel/parallel_reduction.dtg.h"
#include <optional>
#include <unordered_set>

namespace FlexFlow {

ParallelReduction make_parallel_reduction(MultiDiEdge const &,
MultiDiEdge const &);

std::optional<ParallelReduction>
find_parallel_reduction(MultiDiGraphView const &);

/**
* @brief Finds all ExtendedParallelReduction for a given MultiDiGraph
* @details An ExtendedParallelReduction is a unordered collection of
* `MultiDiEdge`s such that they share a common source and destination node.
*/
std::unordered_set<ExtendedParallelReduction>
find_all_extended_parallel_reductions(MultiDiGraphView const &);

MultiDiEdge apply_parallel_reduction(MultiDiGraph &, ParallelReduction const &);

/**
* @brief Applies a given ExtendedParallelReduction in place to a given
* MultiDiGraph
* @details The reduction removes all but one `MultiDiEdge`, so that the source,
* destination nodes associated with the reduction become connected by a single
* edge.
*/
MultiDiEdge
apply_extended_parallel_reduction(MultiDiGraph &,
ExtendedParallelReduction const &);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,25 @@ std::unordered_multiset<Node> get_nodes(SeriesSplit const &);
std::unordered_multiset<Node> get_nodes(ParallelSplit const &);
std::unordered_multiset<Node> get_nodes(Node const &);

bool is_empty(Node const &node);
bool is_empty(SeriesSplit const &serial);
bool is_empty(ParallelSplit const &parallel);
bool is_empty(SeriesParallelDecomposition const &sp);

bool has_no_duplicate_nodes(SeriesParallelDecomposition const &sp);

SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp,
Node const &node);

// duplicate nodes within `sp` are counted multiple times
size_t num_nodes(SeriesParallelDecomposition const &sp);

SeriesParallelDecomposition series_composition(
std::vector<SeriesParallelDecomposition> const &sp_compositions);
SeriesParallelDecomposition parallel_composition(
std::unordered_multiset<SeriesParallelDecomposition> const
&sp_compositions);

} // namespace FlexFlow

#endif
45 changes: 45 additions & 0 deletions lib/utils/include/utils/graph/series_parallel/series_reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

#include "utils/graph/multidigraph/multidiedge.dtg.h"
#include "utils/graph/multidigraph/multidigraph.h"
#include "utils/graph/series_parallel/extended_series_reduction.dtg.h"
#include "utils/graph/series_parallel/series_reduction.dtg.h"
#include "utils/hash/vector.h"

namespace FlexFlow {

Expand All @@ -14,8 +16,51 @@ Node get_center_node(MultiDiGraphView const &, SeriesReduction const &);
SeriesReduction make_series_reduction(MultiDiEdge const &, MultiDiEdge const &);
std::optional<SeriesReduction> find_series_reduction(MultiDiGraphView const &);

/**
* @brief Finds all the ExtendedSeriesReduction structures in a given graph.
*
* @details An `ExtendedSeriesReduction` is an ordered collection of
* `MultiDiEdges` such that:
* - The destination node of the nth edge is the same as the source node of the
* (n+1)th edge.
* - Such a node (intermediate node) has exactly two edges: one incoming (nth
* edge) and one outgoing ((n+1)th edge).
*
* For example, in the following graph:
*
* A -> B -> D -> E
* \ /
* -> C ->
*
* We have that [(A,B), (B,D), (D,E)] and [(A,C), (C,E)] both constitute
* `ExtendedSeriesReduction`.
*/
std::unordered_set<ExtendedSeriesReduction>
find_all_extended_series_reductions(MultiDiGraphView const &g);

MultiDiEdge apply_series_reduction(MultiDiGraph &, SeriesReduction const &);

/**
* @brief Applies a given ExtendedSeriesReduction in-place to a given graph.
*
* For example, in the following graph:
*
* A -> B -> D -> E
* \ /
* -> C ->
*
* Given the ExtendedSeriesReduction [(A,B), (B,D), (D,E)], the intermediate
*nodes B, D, will be deleted, and the resulting graph will be:
*
* A ----> E
* \ /
* -> C ->
*
**/
MultiDiEdge
apply_extended_series_reduction(MultiDiGraph &g,
ExtendedSeriesReduction const &reduction);

} // namespace FlexFlow

#endif
Loading
Loading