Skip to content

Commit

Permalink
Merge pull request #23 from sifive/filter
Browse files Browse the repository at this point in the history
Add a Node.filter() method for generically filtering for nodes
  • Loading branch information
nategraff-sifive authored Nov 18, 2019
2 parents 893d9b4 + be2ffa7 commit 3abf78e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
25 changes: 18 additions & 7 deletions pydevicetree/ast/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ElementList = Iterable[Union['Node', Property, Directive]]

# Callback type signatures for Devicetree.match() and Devicetree.chosen()
MatchFunc = Callable[['Node'], bool]
MatchCallback = Optional[Callable[['Node'], None]]
ChosenCallback = Optional[Callable[[PropertyValues], None]]

Expand Down Expand Up @@ -229,6 +230,22 @@ def get_by_path(self, path: Union[Path, str]) -> Optional['Node']:
return matching_nodes[0]
return None

def filter(self, matchFunc: MatchFunc, cbFunc: MatchCallback = None) -> List['Node']:
"""Filter all child nodes by matchFunc
If cbFunc is provided, this method will iterate over the Nodes selected by matchFunc
and call cbFunc on each Node
Returns a list of all matching Nodes
"""
nodes = list(filter(matchFunc, self.child_nodes()))

if cbFunc is not None:
for n in nodes:
cbFunc(n)

return nodes

def match(self, compatible: Pattern, func: MatchCallback = None) -> List['Node']:
"""Get a node from the subtree by compatible string
Expand All @@ -242,13 +259,7 @@ def match_compat(node: Node) -> bool:
return any(regex.match(c) for c in compatibles)
return False

nodes = list(filter(match_compat, self.child_nodes()))

if func is not None:
for n in nodes:
func(n)

return nodes
return self.filter(match_compat, func)

def child_nodes(self) -> Iterable['Node']:
"""Get an iterable over all the nodes in the subtree"""
Expand Down
11 changes: 11 additions & 0 deletions tests/test_devicetree.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ def test_cells(self):
self.assertEqual(cpu1.address_cells(), 2)
self.assertEqual(cpu1.size_cells(), 1)

def test_filter(self):
def matchFunc(n):
return n.get_field("reg") == 1
def cbFunc(n):
self.assertEqual(n.get_field("compatible"), "riscv");

filtered_nodes = self.tree.filter(matchFunc, cbFunc)

# only the cpu1 node fulfills the matchFunc check
self.assertEqual(len(filtered_nodes), 1)

def test_match(self):
def func(cpu):
self.assertEqual(type(cpu), Node)
Expand Down

0 comments on commit 3abf78e

Please sign in to comment.