Skip to content

Commit

Permalink
Re-factor progress to avoid the use of a global.
Browse files Browse the repository at this point in the history
  • Loading branch information
mihaeladuta committed Nov 22, 2024
1 parent a2f1871 commit 8bd5812
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
8 changes: 4 additions & 4 deletions l2gv2/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _fennel_clustering(
deltas = -alpha * gamma * (partition_sizes ** (gamma - 1))

with numba.objmode:
progress.reset(num_nodes)
pbar = progress.reset(num_nodes)

for it in range(num_iters):
not_converged = 0
Expand Down Expand Up @@ -227,17 +227,17 @@ def _fennel_clustering(
if i % 10000 == 0 and i > 0:
progress_it = i
with numba.objmode:
progress.update(10000)
progress.update(pbar, 10000)
with numba.objmode:
progress.update(num_nodes - progress_it)
progress.update(pbar, num_nodes - progress_it)

print("iteration: " + str(it) + ", not converged: " + str(not_converged))

if not_converged == 0:
print(f"converged after {it} iterations.")
break
with numba.objmode:
progress.close()
progress.close(pbar)

return clusters
# pylint: enable=too-many-branches
Expand Down
12 changes: 6 additions & 6 deletions l2gv2/network/npgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,14 +338,14 @@ def _memmap_degree(edge_index, num_nodes):
degree = np.zeros(num_nodes, dtype=np.int64)
with numba.objmode:
print("computing degrees")
progress.reset(edge_index.shape[1])
pbar = progress.reset(edge_index.shape[1])
for it, source in enumerate(edge_index[0]):
degree[source] += 1
if it % 1000000 == 0 and it > 0:
with numba.objmode:
progress.update(1000000)
progress.update(pbar, 1000000)
with numba.objmode:
progress.close()
progress.close(pbar)
return degree


Expand Down Expand Up @@ -462,7 +462,7 @@ def partition_graph_edges(self, partition, self_loops):
num_edges = self.num_edges
with numba.objmode:
print("finding partition edges")
progress.reset(num_edges)
pbar = progress.reset(num_edges)
num_clusters = partition.max() + 1
edge_counts = np.zeros((num_clusters, num_clusters), dtype=np.int64)
for i, (source, target) in enumerate(self.edge_index.T):
Expand All @@ -472,9 +472,9 @@ def partition_graph_edges(self, partition, self_loops):
edge_counts[source, target] += 1
if i % 1000000 == 0 and i > 0:
with numba.objmode:
progress.update(1000000)
progress.update(pbar, 1000000)
with numba.objmode:
progress.close()
progress.close(pbar)
index = np.nonzero(edge_counts)
partition_edges = np.vstack(index)
weights = np.empty((len(index[0]),), dtype=np.int64)
Expand Down
7 changes: 3 additions & 4 deletions l2gv2/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,15 @@

def reset(total):
""" TODO: docstring for reset"""
global pbar
pbar = tqdm(total=total)
return tqdm(total=total)


def update(iterations):
def update(pbar, iterations):
""" TODO: docstring for update"""
pbar.update(iterations)


def close():
def close(pbar):
""" TODO: docstring for close"""
pbar.update(pbar.total - pbar.n)
pbar.close()

0 comments on commit 8bd5812

Please sign in to comment.