-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
Kruskal.cs
191 lines (166 loc) · 6.55 KB
/
Kruskal.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
using System;
using System.Collections.Generic;
using DataStructures.DisjointSet;
namespace Algorithms.Graph.MinimumSpanningTree;
/// <summary>
/// Algorithm to determine the minimum spanning forest of an undirected graph.
/// </summary>
/// <remarks>
/// Kruskal's algorithm is a greedy algorithm that can determine the
/// minimum spanning tree or minimum spanning forest of any undirected
/// graph. Unlike Prim's algorithm, Kruskal's algorithm will work on
/// graphs that are unconnected. This algorithm will always have a
/// running time of O(E log V) where E is the number of edges and V is
/// the number of vertices/nodes.
/// More information: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm .
/// Pseudocode and analysis: https://www.personal.kent.edu/~rmuhamma/Algorithms/MyAlgorithms/GraphAlgor/kruskalAlgor.htm .
/// </remarks>
public static class Kruskal
{
    /// <summary>
    /// Largest absolute difference tolerated between the two directions of
    /// an edge before the graph is rejected as not undirected.
    /// </summary>
    private const double SymmetryTolerance = 1e-6;

    /// <summary>
    /// Determine the minimum spanning tree/forest of the given graph.
    /// </summary>
    /// <param name="adjacencyMatrix">
    /// Adjacency matrix representing the graph. Entries that are not finite
    /// (infinity or NaN) are treated as "no edge".
    /// </param>
    /// <returns>
    /// Adjacency matrix of the minimum spanning tree/forest. Every entry that
    /// is not part of the result is <see cref="float.PositiveInfinity"/>.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="adjacencyMatrix"/> is null.</exception>
    /// <exception cref="ArgumentException">The matrix is not square or not symmetric.</exception>
    public static float[,] Solve(float[,] adjacencyMatrix)
    {
        if (adjacencyMatrix is null)
        {
            throw new ArgumentNullException(nameof(adjacencyMatrix));
        }

        ValidateGraph(adjacencyMatrix);

        var numNodes = adjacencyMatrix.GetLength(0);
        var set = new DisjointSet<int>();
        var nodes = new Node<int>[numNodes];
        var edgeWeightList = new List<float>();
        var nodeConnectList = new List<(int, int)>();

        // Add nodes to disjoint set
        for (var i = 0; i < numNodes; i++)
        {
            nodes[i] = set.MakeSet(i);
        }

        // The matrix is symmetric, so scanning the upper triangle visits
        // every undirected edge exactly once.
        for (var i = 0; i < numNodes - 1; i++)
        {
            for (var j = i + 1; j < numNodes; j++)
            {
                if (float.IsFinite(adjacencyMatrix[i, j]))
                {
                    edgeWeightList.Add(adjacencyMatrix[i, j]);
                    nodeConnectList.Add((i, j));
                }
            }
        }

        var edges = Solve(set, nodes, edgeWeightList.ToArray(), nodeConnectList.ToArray());

        // Start from a matrix with no edges at all, then fill in the chosen ones.
        var mst = new float[numNodes, numNodes];
        for (var i = 0; i < numNodes; i++)
        {
            for (var j = 0; j < numNodes; j++)
            {
                mst[i, j] = float.PositiveInfinity;
            }
        }

        foreach (var (node1, node2) in edges)
        {
            var weight = adjacencyMatrix[node1, node2];
            mst[node1, node2] = weight;
            mst[node2, node1] = weight;
        }

        return mst;
    }

    /// <summary>
    /// Determine the minimum spanning tree/forest of the given graph.
    /// </summary>
    /// <param name="adjacencyList">
    /// Adjacency list representing the graph: entry <c>i</c> maps each
    /// neighbour of node <c>i</c> to the weight of the connecting edge.
    /// </param>
    /// <returns>Adjacency list of the minimum spanning tree/forest.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="adjacencyList"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// An edge references a node that does not exist, or the graph is not undirected.
    /// </exception>
    public static Dictionary<int, float>[] Solve(Dictionary<int, float>[] adjacencyList)
    {
        if (adjacencyList is null)
        {
            throw new ArgumentNullException(nameof(adjacencyList));
        }

        ValidateGraph(adjacencyList);

        var numNodes = adjacencyList.Length;
        var set = new DisjointSet<int>();
        var nodes = new Node<int>[numNodes];
        var edgeWeightList = new List<float>();
        var nodeConnectList = new List<(int, int)>();

        // Add nodes to disjoint set and create list of edge weights and associated connectivity
        for (var i = 0; i < numNodes; i++)
        {
            nodes[i] = set.MakeSet(i);

            foreach (var (node, weight) in adjacencyList[i])
            {
                // ValidateGraph guarantees the reverse entry exists, so taking
                // only the (lower, higher) orientation lists each undirected
                // edge once. Self-loops are dropped; they can never join two
                // different components anyway.
                if (i < node)
                {
                    edgeWeightList.Add(weight);
                    nodeConnectList.Add((i, node));
                }
            }
        }

        var edges = Solve(set, nodes, edgeWeightList.ToArray(), nodeConnectList.ToArray());

        // Create minimum spanning tree
        var mst = new Dictionary<int, float>[numNodes];
        for (var i = 0; i < numNodes; i++)
        {
            mst[i] = new Dictionary<int, float>();
        }

        foreach (var (node1, node2) in edges)
        {
            var weight = adjacencyList[node1][node2];
            mst[node1].Add(node2, weight);
            mst[node2].Add(node1, weight);
        }

        return mst;
    }

    /// <summary>
    /// Ensure that the given graph is undirected.
    /// </summary>
    /// <param name="adj">Adjacency matrix of graph to check.</param>
    /// <exception cref="ArgumentException">The matrix is not square or not symmetric.</exception>
    private static void ValidateGraph(float[,] adj)
    {
        var size = adj.GetLength(0);

        if (size != adj.GetLength(1))
        {
            throw new ArgumentException("Matrix must be square!");
        }

        // Compare each entry above the diagonal with its mirror below it.
        for (var i = 0; i < size - 1; i++)
        {
            for (var j = i + 1; j < size; j++)
            {
                if (Math.Abs(adj[i, j] - adj[j, i]) > SymmetryTolerance)
                {
                    throw new ArgumentException("Matrix must be symmetric!");
                }
            }
        }
    }

    /// <summary>
    /// Ensure that the given graph is undirected.
    /// </summary>
    /// <param name="adj">Adjacency list of graph to check.</param>
    /// <exception cref="ArgumentException">
    /// An edge references a node index outside the list, or an edge has no
    /// matching reverse edge of the same weight (within tolerance).
    /// </exception>
    private static void ValidateGraph(Dictionary<int, float>[] adj)
    {
        for (var i = 0; i < adj.Length; i++)
        {
            foreach (var (neighbor, weight) in adj[i])
            {
                // Reject dangling edges here instead of letting the array
                // access below fail with an IndexOutOfRangeException.
                if (neighbor < 0 || neighbor >= adj.Length)
                {
                    throw new ArgumentException("Graph must only reference existing nodes!");
                }

                if (!adj[neighbor].TryGetValue(i, out var reverseWeight)
                    || Math.Abs(weight - reverseWeight) > SymmetryTolerance)
                {
                    throw new ArgumentException("Graph must be undirected!");
                }
            }
        }
    }

    /// <summary>
    /// Determine the minimum spanning tree/forest.
    /// </summary>
    /// <param name="set">Disjoint set needed for set operations.</param>
    /// <param name="nodes">List of nodes in disjoint set associated with each node.</param>
    /// <param name="edgeWeights">Weights of each edge. Sorted in place.</param>
    /// <param name="connections">
    /// Nodes associated with each item in the <paramref name="edgeWeights"/> parameter.
    /// Reordered in place together with <paramref name="edgeWeights"/>.
    /// </param>
    /// <returns>Array of edges in the minimum spanning tree/forest.</returns>
    private static (int, int)[] Solve(DisjointSet<int> set, Node<int>[] nodes, float[] edgeWeights, (int, int)[] connections)
    {
        var edges = new List<(int, int)>();

        // A spanning tree over V nodes has exactly V - 1 edges; once that many
        // are accepted every remaining edge would be rejected.
        var maxEdges = nodes.Length - 1;

        // Sort the weights and carry the matching endpoints along with them.
        Array.Sort(edgeWeights, connections);

        foreach (var (node1, node2) in connections)
        {
            // Accept the edge only if it joins two different components.
            if (set.FindSet(nodes[node1]) != set.FindSet(nodes[node2]))
            {
                set.UnionSet(nodes[node1], nodes[node2]);
                edges.Add((node1, node2));

                if (edges.Count == maxEdges)
                {
                    break;
                }
            }
        }

        return edges.ToArray();
    }
}