Skip to content

Commit

Permalink
Can add only nodes to nodelist. (#2369)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgamiye authored Jun 21, 2023
1 parent 7197b37 commit a85f6ee
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 1 deletion.
35 changes: 34 additions & 1 deletion share/lib/python/neuron/rxd/nodelist.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from .rxdException import RxDException
from .node import Node
import types
from collections import abc


class NodeList(list):
def __init__(self, items):
"""Constructs a NodeList from items, a python iterable containing Node objects."""
list.__init__(self, items)
if isinstance(items, abc.Generator) or isinstance(items, abc.Iterator):
items = list(items)

if items == [] or all(isinstance(item, Node) for item in items):
list.__init__(self, items)
else:
raise TypeError("Items must be nodes.")

def __call__(self, restriction):
"""returns a sub-NodeList consisting of nodes satisfying restriction"""
Expand All @@ -16,6 +25,30 @@ def __getitem__(self, key):
else:
return list.__getitem__(self, key)

def __setitem__(self, index, value):
if not isinstance(value, Node):
raise TypeError("Only assign a node to the list")
super().__setitem__(index, value)

def append(self, items):
if not isinstance(items, Node):
raise TypeError("The append item must be a Node.")
super().append(items)

def extend(self, items):
if isinstance(items, abc.Generator) or isinstance(items, abc.Iterator):
items = list(items)

for item in items:
if not isinstance(item, Node):
raise TypeError("The extended items must all be Nodes.")
super().extend(items)

def insert(self, position, items):
if not isinstance(items, Node):
raise TypeError("The item inserted must be a Node.")
super().insert(position, items)

@property
def value(self):
# TODO: change this when not everything is a concentration
Expand Down
90 changes: 90 additions & 0 deletions test/rxd/test_nodelist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
def test_only_nodes(neuron_instance):
"""Test to make sure node lists only contain nodes"""

h, rxd, data, save_path = neuron_instance

dend = h.Section("dend")
r = rxd.Region(h.allsec())
hydrogen = rxd.Species(r, initial=1)
water = rxd.Species(r, initial=1)

h.finitialize(-65)

nodelist = hydrogen.nodes

# test that should not work, so an append that succeeds is an error

try:
nodelist.append(water.nodes) # append nodelist
raise Exception("should not get here")
except TypeError:
...

try:
nodelist.extend([1, 2, 3, water.nodes[0]]) # extend with non-nodes
raise Exception("should not get here")
except TypeError:
...

try:
nodelist[0] = 17
raise Exception("should not get here")
except TypeError:
...

try:
nl = rxd.nodelist.NodeList(
[1, 2, 3, water.nodes[0]]
) # create NodeList with non-nodes
raise Exception("should not get here")
except TypeError:
...

try:
nodelist.insert(1, "llama") # insert non-node into nodelist
raise Exception("should not get here")
except TypeError:
...

# test that should work, so getting in the except is an error
try:
nodelist.append(water.nodes[0]) # append node
except TypeError:
raise Exception("should not get here")

try:
original_length = len(nodelist) # extend nodes
nodelist.extend(item for item in water.nodes)
assert len(nodelist) == original_length + len(water.nodes)
except TypeError:
raise Exception("should not get here")

try:
nodelist[0] = water.nodes[0]
except TypeError:
raise Exception("should not get here")

try:
nl = rxd.nodelist.NodeList(
[water.nodes[0], water.nodes[0]]
) # create nodelist with nodes
except TypeError:
raise Exception("should not get here")

try:
nodelist.insert(1, water.nodes[0]) # insert node into nodelist
except TypeError:
raise Exception("should not get here")

try:
nl = rxd.nodelist.NodeList([]) # create empty nodelist
except TypeError:
raise Exception("should not get here")

try:
nl = rxd.nodelist.NodeList(
item for item in [water.nodes[0], water.nodes[0]]
) # create nodelist with nodes generator
assert len(nl) == 2
except TypeError:
raise Exception("should not get here")

0 comments on commit a85f6ee

Please sign in to comment.