Skip to content

Commit

Permalink
🎨 #25 added optional parameter to filter random walks based on an ent…
Browse files Browse the repository at this point in the history
…ity list
  • Loading branch information
ferzcam committed Jul 27, 2022
1 parent 50ac7f2 commit d732a9a
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 17 deletions.
24 changes: 21 additions & 3 deletions gateway/src/main/scala/org/mowl/Walking/DeepWalk.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import collection.JavaConverters._
import java.io._
import java.util.{HashMap, ArrayList}
import scala.collection.mutable.{MutableList, ListBuffer, Map, ArrayBuffer}
import scala.collection.immutable.HashSet
import util.control.Breaks._
import java.util.concurrent.{ExecutorService, Executors}
import scala.concurrent.ExecutionContext.Implicits.global
Expand All @@ -19,7 +20,9 @@ class DeepWalk (
var walkLength: Int,
var alpha: Float,
var workers: Int,
var outfile: String) {
var outfile: String,
var nodesOfInterest: ArrayList[String]
) {


val edgesSc = edges.asScala.map(x => (x.src, x.rel, x.dst))
Expand All @@ -36,6 +39,8 @@ class DeepWalk (
val rand = scala.util.Random
val (pathsPerWorker, newWorkers) = numPathsPerWorker()

val nodesOfInterestIdx = HashSet() ++ nodesOfInterest.asScala.map(mapEntsIdx(_)).toSet

private[this] val lock = new Object()

val walksFile = new File(outfile)
Expand Down Expand Up @@ -145,9 +150,22 @@ class DeepWalk (
}

val toWrite = walk.filter(_ != -1).map(x => mapIdxEnts(x)).mkString(" ") + "\n"
lock.synchronized {
bw.write(toWrite)

if (nodesOfInterest.size > 0){
val walkSet = HashSet() ++ walk.toSet
val intersection = walkSet & nodesOfInterestIdx

if (intersection.size > 0){


lock.synchronized {
bw.write(toWrite)
}
}
}else{
lock.synchronized {
bw.write(toWrite)
}
}

}
Expand Down
25 changes: 21 additions & 4 deletions gateway/src/main/scala/org/mowl/Walking/Node2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import collection.JavaConverters._
import java.io._
import java.util.{ArrayList}
import scala.collection.mutable.{MutableList, ListBuffer, Stack, Map, HashMap, ArrayBuffer}
import scala.collection.immutable.HashSet
import util.control.Breaks._
import java.util.concurrent.{ExecutorService, Executors}
import scala.concurrent.ExecutionContext.Implicits.global
Expand All @@ -20,7 +21,9 @@ class Node2Vec (
var p: Float,
var q: Float,
var workers: Int,
var outfile: String) {
var outfile: String,
var nodesOfInterest: ArrayList[String]
) {


val edgesSc = edges.asScala.map(x => (x.src, x.rel, x.dst, x.weight))
Expand All @@ -42,6 +45,8 @@ class Node2Vec (
var aliasNodes = Map[Int, (Array[Int], Array[Float])]()
var aliasEdges = Map[(Int, Int), (Array[Int], Array[Float])]()

val nodesOfInterestIdx = HashSet() ++ nodesOfInterest.asScala.map(mapEntsIdx(_)).toSet

private[this] val lock = new Object()

val walksFile = new File(outfile)
Expand Down Expand Up @@ -196,14 +201,26 @@ class Node2Vec (
}

val toWrite = walk.filter(_ != -1).map(x => mapIdxEnts(x)).mkString(" ") + "\n"
lock.synchronized {
bw.write(toWrite)

if (nodesOfInterest.size > 0){
val walkSet = HashSet() ++ walk.toSet
val intersection = walkSet & nodesOfInterestIdx

if (intersection.size > 0){

lock.synchronized {
bw.write(toWrite)
}
}
}else{
lock.synchronized {
bw.write(toWrite)
}
}

}



def getAliasEdge(src: Int, dst: Int) = {
if (graph.contains(dst)){
val dstNbrs = graph(dst).map(_._2)
Expand Down
22 changes: 19 additions & 3 deletions mowl/walking/deepwalk/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from java.util import HashMap
from java.util import ArrayList
from org.mowl import Edge
from mowl.projection.edge import Edge as PyEdge
from deprecated.sphinx import versionchanged

logging.basicConfig(level=logging.INFO)


Expand Down Expand Up @@ -36,13 +39,26 @@ def __init__(self,

self.alpha = alpha

def walk(self, edges):

@versionchanged(version = "0.1.0", reason = "The method now can accept a list of entities to focus on when generating the random walks.")
def walk(self, edges, nodes_of_interest = None):
if nodes_of_interest is None:
nodes_of_interest = ArrayList()
else:
all_nodes, _ = PyEdge.getEntitiesAndRelations(edges)
all_nodes = set(all_nodes)
python_nodes = nodes_of_interest[:]
nodes_of_interest = ArrayList()
for node in python_nodes:
if node in all_nodes:
nodes_of_interest.add(node)
else:
logging.info(f"Node {node} does not exist in graph. Ignoring it.")

edgesJ = ArrayList()
for edge in edges:
newEdge = Edge(edge.src(), edge.rel(), edge.dst())
edgesJ.add(newEdge)

walker = DW(edgesJ, self.num_walks, self.walk_length, self.alpha, self.workers, self.outfile)
walker = DW(edgesJ, self.num_walks, self.walk_length, self.alpha, self.workers, self.outfile, nodes_of_interest)

walker.walk()
24 changes: 19 additions & 5 deletions mowl/walking/node2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from java.util import ArrayList
from org.mowl import Edge
from org.mowl.Walking import Node2Vec as N2V
from mowl.projection.edge import Edge as PyEdge
from deprecated.sphinx import versionchanged

logging.basicConfig(level=logging.INFO)

class Node2Vec(WalkingModel):

Expand Down Expand Up @@ -33,16 +37,26 @@ def __init__(self,
self.p = p
self.q = q

def walk(self, edges):

def walk(self, edges, nodes_of_interest = None):
if nodes_of_interest is None:
nodes_of_interest = ArrayList()
else:
all_nodes, _ = PyEdge.getEntitiesAndRelations(edges)
all_nodes = set(all_nodes)
python_nodes = nodes_of_interest[:]
nodes_of_interest = ArrayList()
for node in python_nodes:
if node in all_nodes:
nodes_of_interest.add(node)
else:
logging.info(f"Node {node} does not exist in graph. Ignoring it.")

edgesJ = ArrayList()

for edge in edges:
newEdge = Edge(edge.src(), edge.rel(), edge.dst(), edge.weight())

edgesJ.add(newEdge)

walker = N2V(edgesJ, self.num_walks, self.walk_length, self.p, self.q, self.workers, self.outfile)
walker = N2V(edgesJ, self.num_walks, self.walk_length, self.p, self.q, self.workers, self.outfile, nodes_of_interest)

walker.walk()

Expand Down
10 changes: 8 additions & 2 deletions mowl/walking/walking.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from deprecated.sphinx import versionchanged

class WalkingModel():

Expand All @@ -14,14 +15,19 @@ def __init__(self, num_walks, walk_length, outfile, workers=1):
self.walk_length = walk_length
self.workers = workers
self.outfile = outfile


# Abstract methods
def walk(self, edges):
@versionchanged(version = "0.1.0", reason = "The method now can accept a list of entities to focus on when generating the random walks.")
def walk(self, edges, nodes_of_interest = None):

'''
This method will generate the walks.
This method will generate random walks from a graph in the form of edgelist.
:param edges: List of edges
:type edges: :class:`mowl.graph.edge.Edge`
:param nodes_of_interest: List of entity names to filter the generated walks. If a walk contains at least one word of interest, it will be saved into disk, otherwise it will be ignored. If no list is input, all the nodes will be considered. Defaults to ``None``
:type nodes_of_interest: list, optional
'''

raise NotImplementedError()
Expand Down

0 comments on commit d732a9a

Please sign in to comment.