Skip to content

Commit

Permalink
Separate Connections
Browse files Browse the repository at this point in the history
  • Loading branch information
david-ragazzi committed Apr 26, 2015
1 parent 98657cd commit fc89181
Show file tree
Hide file tree
Showing 6 changed files with 510 additions and 450 deletions.
283 changes: 283 additions & 0 deletions nupic/research/connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2014, Numenta, Inc. Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------

from collections import defaultdict, namedtuple



class Connections(object):
"""
Class to hold data representing the connectivity of a collection of
elements (cells or columns).
"""

SynapseData = namedtuple("SynapseData", ["segment",
"presynapticCell",
"permanence"])

def __init__(self, numElements):
"""
@param numElements (int) Number of elements in collection. An element can be a
cell (in case of distal connections) or a column (in case of proximal connections)
"""

# Save member variables
self.numElements = numElements

# Mappings
self._segments = dict()
self._synapses = dict()

# Indexes into the mappings (for performance)
self._segmentsForElement = dict()
self._synapsesForSegment = dict()
self._synapsesForPresynapticCell = defaultdict(dict)

# Index of the next segment to be created
self._nextSegmentIdx = 0
# Index of the next synapse to be created
self._nextSynapseIdx = 0


def elementForSegment(self, segment):
"""
Returns the cell or column (i.e. element) that a segment belongs to.
@param segment (int) Segment index
@return (int) Cell or Column index
"""
return self._segments[segment]


def segmentsForElement(self, element):
"""
Returns the segments that belong to a cell or column (i.e. element).
@param element (int) Cell or Column index
@return (set) Segment indices
"""
self._validateElement(element)

if not element in self._segmentsForElement:
return set()

return self._segmentsForElement[element]


def dataForSynapse(self, synapse):
"""
Returns the data for a synapse.
@param synapse (int) Synapse index
@return (SynapseData) Synapse data
"""
self._validateSynapse(synapse)

return self._synapses[synapse]


def synapsesForSegment(self, segment):
"""
Returns the synapses on a segment.
@param segment (int) Segment index
@return (set) Synapse indices
"""
self._validateSegment(segment)

if not segment in self._synapsesForSegment:
return set()

return self._synapsesForSegment[segment]


def synapsesForPresynapticCell(self, presynapticCell):
"""
Returns the synapses for the source cell that they synapse on.
@param presynapticCell (int) Source cell index
@return (set) Synapse indices
"""
return self._synapsesForPresynapticCell[presynapticCell]


def createSegment(self, element):
"""
Adds a new segment on a cell or column (i.e. element).
Note that while a cell can have one or more segments (making distal connections to lateral cells),
a column have only a single segment (making proximal connections to input cells).
@param element (int) Cell or Column index
@return (int) New segment index
"""
self._validateElement(element)

# Add data
segment = self._nextSegmentIdx
self._segments[segment] = element
self._nextSegmentIdx += 1

# Update indexes
if not element in self._segmentsForElement:
self._segmentsForElement[element] = set()
self._segmentsForElement[element].add(segment)

return segment


def destroySegment(self, segment):
"""
Destroys a segment.
@param segment (int) Segment index
"""
synapses = set(self.synapsesForSegment(segment))
for synapse in synapses:
self.destroySynapse(synapse)

element = self._segments[segment]
del self._segments[segment]

# Update indexes
self._segmentsForElement[element].remove(segment)


def createSynapse(self, segment, presynapticCell, permanence):
"""
Creates a new synapse on a segment.
@param segment (int) Segment index
@param presynapticCell (int) Source cell index
@param permanence (float) Initial permanence
@return (int) Synapse index
"""
self._validateSegment(segment)
self._validatePermanence(permanence)

# Add data
synapse = self._nextSynapseIdx
synapseData = self.SynapseData(segment, presynapticCell, permanence)
self._synapses[synapse] = synapseData
self._nextSynapseIdx += 1

# Update indexes
if not len(self.synapsesForSegment(segment)):
self._synapsesForSegment[segment] = set()
self._synapsesForSegment[segment].add(synapse)

self._synapsesForPresynapticCell[presynapticCell][synapse] = synapseData

return synapse


def destroySynapse(self, synapse):
"""
Destroys a synapse.
@param synapse (int) Synapse index
"""
data = self._synapses[synapse]
del self._synapses[synapse]

# Update indexes
self._synapsesForSegment[data.segment].remove(synapse)
del self._synapsesForPresynapticCell[data.presynapticCell][synapse]


def updateSynapsePermanence(self, synapse, permanence):
"""
Updates the permanence for a synapse.
@param synapse (int) Synapse index
@param permanence (float) New permanence
"""
self._validatePermanence(permanence)

data = self._synapses[synapse]
newData = self.SynapseData(data.segment,
data.presynapticCell,
permanence)
self._synapses[synapse] = newData

# Update indexes
self._synapsesForPresynapticCell[newData.presynapticCell][synapse] = newData


def numSegments(self):
"""
Returns the number of segments.
"""
return len(self._segments)


def numSynapses(self):
"""
Returns the number of synapses.
"""
return len(self._synapses)


def _validateElement(self, element):
"""
Raises an error if element index is invalid.
@param element (int) Cell or Column index
"""
if element >= self.numElements or element < 0:
raise IndexError("Invalid column or cell")


def _validateSegment(self, segment):
"""
Raises an error if segment index is invalid.
@param segment (int) Segment index
"""
if not segment in self._segments:
raise IndexError("Invalid segment")


def _validateSynapse(self, synapse):
"""
Raises an error if synapse index is invalid.
@param synapse (int) Synapse index
"""
if not synapse in self._synapses:
raise IndexError("Invalid synapse")


@staticmethod
def _validatePermanence(permanence):
"""
Raises an error if permanence is invalid.
@param permanence (float) Permanence
"""
if permanence < 0 or permanence > 1:
raise ValueError("Invalid permanence")
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def mmPrettyPrintConnections(self):
for cell in cells:
segmentDict = dict()

for seg in self.connections.segmentsForCell(cell):
for seg in self.connections.segmentsForElement(cell):
synapseList = []

for synapse in self.connections.synapsesForSegment(seg):
Expand Down
Loading

0 comments on commit fc89181

Please sign in to comment.