diff --git a/src/pecanpy/graph.py b/src/pecanpy/graph.py index 25254c05..05a5a539 100644 --- a/src/pecanpy/graph.py +++ b/src/pecanpy/graph.py @@ -77,6 +77,8 @@ class AdjlstGraph(BaseGraph): >>> indptr, indices, data = g.to_csr() # convert to csr >>> >>> dense_mat = g.to_dense() # convert to dense adjacency matrix + >>> + >>> g.save(edg_outpath) # save the graph to an edge list file """ @@ -222,6 +224,23 @@ def read(self, edg_fp, weighted, directed, delimiter="\t"): edge = self._read_edge_line(edge_line, weighted, delimiter) self.add_edge(*edge, directed) + def save(self, fp: str, unweighted: bool = False, delimiter: str = "\t"): + """Save AdjLst as an ``.edg`` edge list file. + + Args: + unweighted (bool): If set to True, only write two columns, + corresponding to the head and tail nodes of the edges, and + ignore the edge weights (default: :obj:`False`). + delimiter (str): Delimiter for separating fields. + + """ + with open(fp, "w") as f: + for h, t, w in self.edges_iter: + h, t = self.nodes[h], self.nodes[t] # convert index to node id + line = delimiter.join((h, t) if unweighted else (h, t, str(w))) + f.write(line) + f.write("\n") + def to_csr(self): """Construct compressed sparse row matrix.""" indptr = np.zeros(len(self.IDlst) + 1, dtype=np.uint32) diff --git a/test/test_graph.py b/test/test_graph.py index 0abc772b..80671f93 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -1,3 +1,6 @@ +import tempfile +import os +import shutil import unittest import numpy as np @@ -110,6 +113,48 @@ def test_edges2(self): ], ) + def test_save(self): + self.g = AdjlstGraph.from_mat(MAT, IDS) + + expected_results = { + (False, "\t"): [ + "a\tb\t1.0\n", + "a\tc\t1.0\n", + "b\ta\t1.0\n", + "c\ta\t1.0\n", + ], + (True, "\t"): [ + "a\tb\n", + "a\tc\n", + "b\ta\n", + "c\ta\n", + ], + (False, ","): [ + "a,b,1.0\n", + "a,c,1.0\n", + "b,a,1.0\n", + "c,a,1.0\n", + ], + (True, ","): [ + "a,b\n", + "a,c\n", + "b,a\n", + "c,a\n", + ], + } + + tmpdir = tempfile.mkdtemp() + tmpfp = os.path.join(tmpdir, "test.edg") + for unweighted in True, False: + for delimiter in ["\t", ","]: + self.g.save(tmpfp, unweighted=unweighted, delimiter=delimiter) + + with open(tmpfp, "r") as f: + expected_result = expected_results[(unweighted, delimiter)] + for line, expected_line in zip(f, expected_result): + self.assertEqual(line, expected_line) + shutil.rmtree(tmpdir) + class TestSparseGraph(unittest.TestCase): def tearDown(self):